Implement value type abstraction for types.

This is done by changing Type to be a POD interface around an underlying pointer storage and adding in-class support for isa/dyn_cast/cast.

PiperOrigin-RevId: 219372163
This commit is contained in:
River Riddle 2018-10-30 14:59:22 -07:00 committed by jpienaar
parent 75376b8e33
commit 4c465a181d
41 changed files with 998 additions and 811 deletions

View File

@ -51,7 +51,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
/// whether indices[dim] is independent of the value `input`. /// whether indices[dim] is independent of the value `input`.
// For now we assume no layout map or identity layout map in the MemRef. // For now we assume no layout map or identity layout map in the MemRef.
// TODO(ntv): support more than identity layout map. // TODO(ntv): support more than identity layout map.
bool isAccessInvariant(const MLValue &input, MemRefType *memRefType, bool isAccessInvariant(const MLValue &input, MemRefType memRefType,
llvm::ArrayRef<MLValue *> indices, unsigned dim); llvm::ArrayRef<MLValue *> indices, unsigned dim);
/// Checks whether all the LoadOp and StoreOp matched have access indexing /// Checks whether all the LoadOp and StoreOp matched have access indexing

View File

@ -250,9 +250,9 @@ public:
TypeAttr() = default; TypeAttr() = default;
/* implicit */ TypeAttr(Attribute::ImplType *ptr); /* implicit */ TypeAttr(Attribute::ImplType *ptr);
static TypeAttr get(Type *type, MLIRContext *context); static TypeAttr get(Type type, MLIRContext *context);
Type *getValue() const; Type getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Type; } static bool kindof(Kind kind) { return kind == Kind::Type; }
@ -277,7 +277,7 @@ public:
Function *getValue() const; Function *getValue() const;
FunctionType *getType() const; FunctionType getType() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Function; } static bool kindof(Kind kind) { return kind == Kind::Function; }
@ -294,7 +294,7 @@ public:
ElementsAttr() = default; ElementsAttr() = default;
/* implicit */ ElementsAttr(Attribute::ImplType *ptr); /* implicit */ ElementsAttr(Attribute::ImplType *ptr);
VectorOrTensorType *getType() const; VectorOrTensorType getType() const;
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(Kind kind) { static bool kindof(Kind kind) {
@ -313,7 +313,7 @@ public:
SplatElementsAttr() = default; SplatElementsAttr() = default;
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr); /* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
static SplatElementsAttr get(VectorOrTensorType *type, Attribute elt); static SplatElementsAttr get(VectorOrTensorType type, Attribute elt);
Attribute getValue() const; Attribute getValue() const;
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
@ -335,12 +335,12 @@ public:
/// width specified by the element type (note all float type are 64 bits). /// width specified by the element type (note all float type are 64 bits).
/// When the value is retrieved, the bits are read from the storage and extend /// When the value is retrieved, the bits are read from the storage and extend
/// to 64 bits if necessary. /// to 64 bits if necessary.
static DenseElementsAttr get(VectorOrTensorType *type, ArrayRef<char> data); static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data);
// TODO: Read the data from the attribute list and compress them // TODO: Read the data from the attribute list and compress them
// to a character array. Then call the above method to construct the // to a character array. Then call the above method to construct the
// attribute. // attribute.
static DenseElementsAttr get(VectorOrTensorType *type, static DenseElementsAttr get(VectorOrTensorType type,
ArrayRef<Attribute> values); ArrayRef<Attribute> values);
void getValues(SmallVectorImpl<Attribute> &values) const; void getValues(SmallVectorImpl<Attribute> &values) const;
@ -410,7 +410,7 @@ public:
OpaqueElementsAttr() = default; OpaqueElementsAttr() = default;
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr); /* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
static OpaqueElementsAttr get(VectorOrTensorType *type, StringRef bytes); static OpaqueElementsAttr get(VectorOrTensorType type, StringRef bytes);
StringRef getValue() const; StringRef getValue() const;
@ -440,7 +440,7 @@ public:
SparseElementsAttr() = default; SparseElementsAttr() = default;
/* implicit */ SparseElementsAttr(Attribute::ImplType *ptr); /* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
static SparseElementsAttr get(VectorOrTensorType *type, static SparseElementsAttr get(VectorOrTensorType type,
DenseIntElementsAttr indices, DenseIntElementsAttr indices,
DenseElementsAttr values); DenseElementsAttr values);

View File

@ -64,10 +64,10 @@ public:
bool args_empty() const { return arguments.empty(); } bool args_empty() const { return arguments.empty(); }
/// Add one value to the operand list. /// Add one value to the operand list.
BBArgument *addArgument(Type *type); BBArgument *addArgument(Type type);
/// Add one argument to the argument list for each type specified in the list. /// Add one argument to the argument list for each type specified in the list.
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type *> types); llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
unsigned getNumArguments() const { return arguments.size(); } unsigned getNumArguments() const { return arguments.size(); }
BBArgument *getArgument(unsigned i) { return arguments[i]; } BBArgument *getArgument(unsigned i) { return arguments[i]; }

View File

@ -68,29 +68,28 @@ public:
unsigned column); unsigned column);
// Types. // Types.
FloatType *getBF16Type(); FloatType getBF16Type();
FloatType *getF16Type(); FloatType getF16Type();
FloatType *getF32Type(); FloatType getF32Type();
FloatType *getF64Type(); FloatType getF64Type();
OtherType *getIndexType(); OtherType getIndexType();
OtherType *getTFControlType(); OtherType getTFControlType();
OtherType *getTFStringType(); OtherType getTFStringType();
OtherType *getTFResourceType(); OtherType getTFResourceType();
OtherType *getTFVariantType(); OtherType getTFVariantType();
OtherType *getTFComplex64Type(); OtherType getTFComplex64Type();
OtherType *getTFComplex128Type(); OtherType getTFComplex128Type();
OtherType *getTFF32REFType(); OtherType getTFF32REFType();
IntegerType *getIntegerType(unsigned width); IntegerType getIntegerType(unsigned width);
FunctionType *getFunctionType(ArrayRef<Type *> inputs, FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
ArrayRef<Type *> results); MemRefType getMemRefType(ArrayRef<int> shape, Type elementType,
MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType, ArrayRef<AffineMap> affineMapComposition = {},
ArrayRef<AffineMap> affineMapComposition = {}, unsigned memorySpace = 0);
unsigned memorySpace = 0); VectorType getVectorType(ArrayRef<int> shape, Type elementType);
VectorType *getVectorType(ArrayRef<int> shape, Type *elementType); RankedTensorType getTensorType(ArrayRef<int> shape, Type elementType);
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType); UnrankedTensorType getTensorType(Type elementType);
UnrankedTensorType *getTensorType(Type *elementType);
// Attributes. // Attributes.
@ -102,15 +101,15 @@ public:
ArrayAttr getArrayAttr(ArrayRef<Attribute> value); ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
AffineMapAttr getAffineMapAttr(AffineMap map); AffineMapAttr getAffineMapAttr(AffineMap map);
IntegerSetAttr getIntegerSetAttr(IntegerSet set); IntegerSetAttr getIntegerSetAttr(IntegerSet set);
TypeAttr getTypeAttr(Type *type); TypeAttr getTypeAttr(Type type);
FunctionAttr getFunctionAttr(const Function *value); FunctionAttr getFunctionAttr(const Function *value);
ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt); ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt);
ElementsAttr getDenseElementsAttr(VectorOrTensorType *type, ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
ArrayRef<char> data); ArrayRef<char> data);
ElementsAttr getSparseElementsAttr(VectorOrTensorType *type, ElementsAttr getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices, DenseIntElementsAttr indices,
DenseElementsAttr values); DenseElementsAttr values);
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes); ElementsAttr getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes);
// Affine expressions and affine maps. // Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position); AffineExpr getAffineDimExpr(unsigned position);
@ -366,7 +365,7 @@ public:
/// Creates an operation given the fields. /// Creates an operation given the fields.
OperationStmt *createOperation(Location *location, OperationName name, OperationStmt *createOperation(Location *location, OperationName name,
ArrayRef<MLValue *> operands, ArrayRef<MLValue *> operands,
ArrayRef<Type *> types, ArrayRef<Type> types,
ArrayRef<NamedAttribute> attrs); ArrayRef<NamedAttribute> attrs);
/// Create operation of specific op type at the current insertion point. /// Create operation of specific op type at the current insertion point.

View File

@ -96,7 +96,7 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
public: public:
/// Builds a constant op with the specified attribute value and result type. /// Builds a constant op with the specified attribute value and result type.
static void build(Builder *builder, OperationState *result, Attribute value, static void build(Builder *builder, OperationState *result, Attribute value,
Type *type); Type type);
Attribute getValue() const { return getAttr("value"); } Attribute getValue() const { return getAttr("value"); }
@ -123,7 +123,7 @@ class ConstantFloatOp : public ConstantOp {
public: public:
/// Builds a constant float op producing a float of the specified type. /// Builds a constant float op producing a float of the specified type.
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result,
const APFloat &value, FloatType *type); const APFloat &value, FloatType type);
APFloat getValue() const { APFloat getValue() const {
return getAttrOfType<FloatAttr>("value").getValue(); return getAttrOfType<FloatAttr>("value").getValue();
@ -150,7 +150,7 @@ public:
/// Build a constant int op producing an integer with the specified type, /// Build a constant int op producing an integer with the specified type,
/// which must be an integer type. /// which must be an integer type.
static void build(Builder *builder, OperationState *result, int64_t value, static void build(Builder *builder, OperationState *result, int64_t value,
Type *type); Type type);
int64_t getValue() const { int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value").getValue(); return getAttrOfType<IntegerAttr>("value").getValue();

View File

@ -27,7 +27,7 @@ namespace mlir {
// blocks, each of which includes instructions. // blocks, each of which includes instructions.
class CFGFunction : public Function { class CFGFunction : public Function {
public: public:
CFGFunction(Location *location, StringRef name, FunctionType *type, CFGFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {}); ArrayRef<NamedAttribute> attrs = {});
~CFGFunction(); ~CFGFunction();

View File

@ -66,7 +66,7 @@ public:
} }
protected: protected:
CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {} CFGValue(CFGValueKind kind, Type type) : SSAValueImpl(kind, type) {}
}; };
/// Basic block arguments are CFG Values. /// Basic block arguments are CFG Values.
@ -87,7 +87,7 @@ public:
private: private:
friend class BasicBlock; // For access to private constructor. friend class BasicBlock; // For access to private constructor.
BBArgument(Type *type, BasicBlock *owner) BBArgument(Type type, BasicBlock *owner)
: CFGValue(CFGValueKind::BBArgument, type), owner(owner) {} : CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
/// The owner of this operand. /// The owner of this operand.
@ -99,7 +99,7 @@ private:
/// Instruction results are CFG Values. /// Instruction results are CFG Values.
class InstResult : public CFGValue { class InstResult : public CFGValue {
public: public:
InstResult(Type *type, OperationInst *owner) InstResult(Type type, OperationInst *owner)
: CFGValue(CFGValueKind::InstResult, type), owner(owner) {} : CFGValue(CFGValueKind::InstResult, type), owner(owner) {}
static bool classof(const SSAValue *value) { static bool classof(const SSAValue *value) {

View File

@ -26,6 +26,7 @@
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h" #include "mlir/IR/Identifier.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist.h"
@ -55,7 +56,7 @@ public:
Identifier getName() const { return nameAndKind.getPointer(); } Identifier getName() const { return nameAndKind.getPointer(); }
/// Return the type of this function. /// Return the type of this function.
FunctionType *getType() const { return type; } FunctionType getType() const { return type; }
/// Returns all of the attributes on this function. /// Returns all of the attributes on this function.
ArrayRef<NamedAttribute> getAttrs() const; ArrayRef<NamedAttribute> getAttrs() const;
@ -93,7 +94,7 @@ public:
void emitNote(const Twine &message) const; void emitNote(const Twine &message) const;
protected: protected:
Function(Kind kind, Location *location, StringRef name, FunctionType *type, Function(Kind kind, Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {}); ArrayRef<NamedAttribute> attrs = {});
~Function(); ~Function();
@ -108,7 +109,7 @@ private:
Location *location; Location *location;
/// The type of the function. /// The type of the function.
FunctionType *const type; FunctionType type;
/// This holds general named attributes for the function. /// This holds general named attributes for the function.
AttributeListStorage *attrs; AttributeListStorage *attrs;
@ -121,7 +122,7 @@ private:
/// defined in some other module. /// defined in some other module.
class ExtFunction : public Function { class ExtFunction : public Function {
public: public:
ExtFunction(Location *location, StringRef name, FunctionType *type, ExtFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {}); ArrayRef<NamedAttribute> attrs = {});
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.

View File

@ -202,7 +202,7 @@ public:
/// Create a new OperationInst with the specified fields. /// Create a new OperationInst with the specified fields.
static OperationInst *create(Location *location, OperationName name, static OperationInst *create(Location *location, OperationName name,
ArrayRef<CFGValue *> operands, ArrayRef<CFGValue *> operands,
ArrayRef<Type *> resultTypes, ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes, ArrayRef<NamedAttribute> attributes,
MLIRContext *context); MLIRContext *context);

View File

@ -41,7 +41,7 @@ class MLFunction final
public: public:
/// Creates a new MLFunction with the specific type. /// Creates a new MLFunction with the specific type.
static MLFunction *create(Location *location, StringRef name, static MLFunction *create(Location *location, StringRef name,
FunctionType *type, FunctionType type,
ArrayRef<NamedAttribute> attrs = {}); ArrayRef<NamedAttribute> attrs = {});
/// Destroys this statement and its subclass data. /// Destroys this statement and its subclass data.
@ -52,7 +52,7 @@ public:
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
/// Returns number of arguments. /// Returns number of arguments.
unsigned getNumArguments() const { return getType()->getInputs().size(); } unsigned getNumArguments() const { return getType().getInputs().size(); }
/// Gets argument. /// Gets argument.
MLFuncArgument *getArgument(unsigned idx) { MLFuncArgument *getArgument(unsigned idx) {
@ -103,13 +103,13 @@ public:
} }
private: private:
MLFunction(Location *location, StringRef name, FunctionType *type, MLFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {}); ArrayRef<NamedAttribute> attrs = {});
// This stuff is used by the TrailingObjects template. // This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<MLFunction, MLFuncArgument>; friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;
size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const { size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const {
return getType()->getInputs().size(); return getType().getInputs().size();
} }
// Internal functions to get argument list used by getArgument() methods. // Internal functions to get argument list used by getArgument() methods.

View File

@ -73,7 +73,7 @@ public:
} }
protected: protected:
MLValue(MLValueKind kind, Type *type) : SSAValueImpl(kind, type) {} MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {}
}; };
/// This is the value defined by an argument of an ML function. /// This is the value defined by an argument of an ML function.
@ -93,7 +93,7 @@ public:
private: private:
friend class MLFunction; // For access to private constructor. friend class MLFunction; // For access to private constructor.
MLFuncArgument(Type *type, MLFunction *owner) MLFuncArgument(Type type, MLFunction *owner)
: MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {} : MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {}
/// The owner of this operand. /// The owner of this operand.
@ -105,7 +105,7 @@ private:
/// This is a value defined by a result of an operation instruction. /// This is a value defined by a result of an operation instruction.
class StmtResult : public MLValue { class StmtResult : public MLValue {
public: public:
StmtResult(Type *type, OperationStmt *owner) StmtResult(Type type, OperationStmt *owner)
: MLValue(MLValueKind::StmtResult, type), owner(owner) {} : MLValue(MLValueKind::StmtResult, type), owner(owner) {}
static bool classof(const SSAValue *value) { static bool classof(const SSAValue *value) {

View File

@ -71,13 +71,13 @@ struct constant_int_op_binder {
bool match(Operation *op) { bool match(Operation *op) {
if (auto constOp = op->dyn_cast<ConstantOp>()) { if (auto constOp = op->dyn_cast<ConstantOp>()) {
auto *type = constOp->getResult()->getType(); auto type = constOp->getResult()->getType();
auto attr = constOp->getAttr("value"); auto attr = constOp->getAttr("value");
if (isa<IntegerType>(type)) { if (type.isa<IntegerType>()) {
return attr_value_binder<IntegerAttr>(bind_value).match(attr); return attr_value_binder<IntegerAttr>(bind_value).match(attr);
} }
if (isa<VectorOrTensorType>(type)) { if (type.isa<VectorOrTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) { if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value) return attr_value_binder<IntegerAttr>(bind_value)
.match(splatAttr.getValue()); .match(splatAttr.getValue());

View File

@ -493,7 +493,7 @@ public:
return this->getOperation()->getResult(0); return this->getOperation()->getResult(0);
} }
Type *getType() const { return getResult()->getType(); } Type getType() const { return getResult()->getType(); }
/// Replace all uses of 'this' value with the new value, updating anything in /// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns /// the IR that uses 'this' to use the other value instead. When this returns
@ -539,7 +539,7 @@ public:
return this->getOperation()->getResult(i); return this->getOperation()->getResult(i);
} }
Type *getType(unsigned i) const { return getResult(i)->getType(); } Type getType(unsigned i) const { return getResult(i)->getType(); }
static bool verifyTrait(const Operation *op) { static bool verifyTrait(const Operation *op) {
return impl::verifyNResults(op, N); return impl::verifyNResults(op, N);
@ -565,7 +565,7 @@ public:
return this->getOperation()->getResult(i); return this->getOperation()->getResult(i);
} }
Type *getType(unsigned i) const { return getResult(i)->getType(); } Type getType(unsigned i) const { return getResult(i)->getType(); }
static bool verifyTrait(const Operation *op) { static bool verifyTrait(const Operation *op) {
return impl::verifyAtLeastNResults(op, N); return impl::verifyAtLeastNResults(op, N);
@ -803,7 +803,7 @@ protected:
// which avoids them being template instantiated/duplicated. // which avoids them being template instantiated/duplicated.
namespace impl { namespace impl {
void buildCastOp(Builder *builder, OperationState *result, SSAValue *source, void buildCastOp(Builder *builder, OperationState *result, SSAValue *source,
Type *destType); Type destType);
bool parseCastOp(OpAsmParser *parser, OperationState *result); bool parseCastOp(OpAsmParser *parser, OperationState *result);
void printCastOp(const Operation *op, OpAsmPrinter *p); void printCastOp(const Operation *op, OpAsmPrinter *p);
} // namespace impl } // namespace impl
@ -819,7 +819,7 @@ class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
OpTrait::HasNoSideEffect, Traits...> { OpTrait::HasNoSideEffect, Traits...> {
public: public:
static void build(Builder *builder, OperationState *result, SSAValue *source, static void build(Builder *builder, OperationState *result, SSAValue *source,
Type *destType) { Type destType) {
impl::buildCastOp(builder, result, source, destType); impl::buildCastOp(builder, result, source, destType);
} }
static bool parse(OpAsmParser *parser, OperationState *result) { static bool parse(OpAsmParser *parser, OperationState *result) {

View File

@ -67,7 +67,7 @@ public:
printOperand(*it); printOperand(*it);
} }
} }
virtual void printType(const Type *type) = 0; virtual void printType(Type type) = 0;
virtual void printFunctionReference(const Function *func) = 0; virtual void printFunctionReference(const Function *func) = 0;
virtual void printAttribute(Attribute attr) = 0; virtual void printAttribute(Attribute attr) = 0;
virtual void printAffineMap(AffineMap map) = 0; virtual void printAffineMap(AffineMap map) = 0;
@ -95,8 +95,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const SSAValue &value) {
return p; return p;
} }
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Type &type) { inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
p.printType(&type); p.printType(type);
return p; return p;
} }
@ -163,20 +163,20 @@ public:
virtual bool parseComma() = 0; virtual bool parseComma() = 0;
/// Parse a colon followed by a type. /// Parse a colon followed by a type.
virtual bool parseColonType(Type *&result) = 0; virtual bool parseColonType(Type &result) = 0;
/// Parse a type of a specific kind, e.g. a FunctionType. /// Parse a type of a specific kind, e.g. a FunctionType.
template <typename TypeType> bool parseColonType(TypeType *&result) { template <typename TypeType> bool parseColonType(TypeType &result) {
llvm::SMLoc loc; llvm::SMLoc loc;
getCurrentLocation(&loc); getCurrentLocation(&loc);
// Parse any kind of type. // Parse any kind of type.
Type *type; Type type;
if (parseColonType(type)) if (parseColonType(type))
return true; return true;
// Check for the right kind of attribute. // Check for the right kind of attribute.
result = dyn_cast<TypeType>(type); result = type.dyn_cast<TypeType>();
if (!result) { if (!result) {
emitError(loc, "invalid kind of type specified"); emitError(loc, "invalid kind of type specified");
return true; return true;
@ -186,15 +186,15 @@ public:
} }
/// Parse a colon followed by a type list, which must have at least one type. /// Parse a colon followed by a type list, which must have at least one type.
virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result) = 0; virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a keyword followed by a type. /// Parse a keyword followed by a type.
virtual bool parseKeywordType(const char *keyword, Type *&result) = 0; virtual bool parseKeywordType(const char *keyword, Type &result) = 0;
/// Add the specified type to the end of the specified type list and return /// Add the specified type to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and /// false. This is a helper designed to allow parse methods to be simple and
/// chain through || operators. /// chain through || operators.
bool addTypeToList(Type *type, SmallVectorImpl<Type *> &result) { bool addTypeToList(Type type, SmallVectorImpl<Type> &result) {
result.push_back(type); result.push_back(type);
return false; return false;
} }
@ -202,7 +202,7 @@ public:
/// Add the specified types to the end of the specified type list and return /// Add the specified types to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and /// false. This is a helper designed to allow parse methods to be simple and
/// chain through || operators. /// chain through || operators.
bool addTypesToList(ArrayRef<Type *> types, SmallVectorImpl<Type *> &result) { bool addTypesToList(ArrayRef<Type> types, SmallVectorImpl<Type> &result) {
result.append(types.begin(), types.end()); result.append(types.begin(), types.end());
return false; return false;
} }
@ -288,13 +288,13 @@ public:
/// Resolve an operand to an SSA value, emitting an error and returning true /// Resolve an operand to an SSA value, emitting an error and returning true
/// on failure. /// on failure.
virtual bool resolveOperand(const OperandType &operand, Type *type, virtual bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<SSAValue *> &result) = 0; SmallVectorImpl<SSAValue *> &result) = 0;
/// Resolve a list of operands to SSA values, emitting an error and returning /// Resolve a list of operands to SSA values, emitting an error and returning
/// true on failure, or appending the results to the list on success. /// true on failure, or appending the results to the list on success.
/// This method should be used when all operands have the same type. /// This method should be used when all operands have the same type.
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type *type, virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type,
SmallVectorImpl<SSAValue *> &result) { SmallVectorImpl<SSAValue *> &result) {
for (auto elt : operands) for (auto elt : operands)
if (resolveOperand(elt, type, result)) if (resolveOperand(elt, type, result))
@ -306,7 +306,7 @@ public:
/// emitting an error and returning true on failure, or appending the results /// emitting an error and returning true on failure, or appending the results
/// to the list on success. /// to the list on success.
virtual bool resolveOperands(ArrayRef<OperandType> operands, virtual bool resolveOperands(ArrayRef<OperandType> operands,
ArrayRef<Type *> types, llvm::SMLoc loc, ArrayRef<Type> types, llvm::SMLoc loc,
SmallVectorImpl<SSAValue *> &result) { SmallVectorImpl<SSAValue *> &result) {
if (operands.size() != types.size()) if (operands.size() != types.size())
return emitError(loc, Twine(operands.size()) + return emitError(loc, Twine(operands.size()) +
@ -321,7 +321,7 @@ public:
} }
/// Resolve a parse function name and a type into a function reference. /// Resolve a parse function name and a type into a function reference.
virtual bool resolveFunctionName(StringRef name, FunctionType *type, virtual bool resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) = 0; llvm::SMLoc loc, Function *&result) = 0;
/// Emit a diagnostic at the specified location and return true. /// Emit a diagnostic at the specified location and return true.

View File

@ -25,6 +25,7 @@
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h" #include "mlir/IR/Identifier.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/PointerUnion.h"
#include <memory> #include <memory>
@ -191,7 +192,7 @@ struct OperationState {
OperationName name; OperationName name;
SmallVector<SSAValue *, 4> operands; SmallVector<SSAValue *, 4> operands;
/// Types of the results of this operation. /// Types of the results of this operation.
SmallVector<Type *, 4> types; SmallVector<Type, 4> types;
SmallVector<NamedAttribute, 4> attributes; SmallVector<NamedAttribute, 4> attributes;
public: public:
@ -202,7 +203,7 @@ public:
: context(context), location(location), name(name) {} : context(context), location(location), name(name) {}
OperationState(MLIRContext *context, Location *location, StringRef name, OperationState(MLIRContext *context, Location *location, StringRef name,
ArrayRef<SSAValue *> operands, ArrayRef<Type *> types, ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
ArrayRef<NamedAttribute> attributes = {}) ArrayRef<NamedAttribute> attributes = {})
: context(context), location(location), name(name, context), : context(context), location(location), name(name, context),
operands(operands.begin(), operands.end()), operands(operands.begin(), operands.end()),
@ -213,7 +214,7 @@ public:
operands.append(newOperands.begin(), newOperands.end()); operands.append(newOperands.begin(), newOperands.end());
} }
void addTypes(ArrayRef<Type *> newTypes) { void addTypes(ArrayRef<Type> newTypes) {
types.append(newTypes.begin(), newTypes.end()); types.append(newTypes.begin(), newTypes.end());
} }

View File

@ -25,7 +25,6 @@
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/IR/UseDefLists.h" #include "mlir/IR/UseDefLists.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/PointerIntPair.h"
namespace mlir { namespace mlir {
class Function; class Function;
@ -51,7 +50,7 @@ public:
SSAValueKind getKind() const { return typeAndKind.getInt(); } SSAValueKind getKind() const { return typeAndKind.getInt(); }
Type *getType() const { return typeAndKind.getPointer(); } Type getType() const { return typeAndKind.getPointer(); }
/// Replace all uses of 'this' value with the new value, updating anything in /// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns /// the IR that uses 'this' to use the other value instead. When this returns
@ -93,9 +92,10 @@ public:
void dump() const; void dump() const;
protected: protected:
SSAValue(SSAValueKind kind, Type *type) : typeAndKind(type, kind) {} SSAValue(SSAValueKind kind, Type type) : typeAndKind(type, kind) {}
private: private:
const llvm::PointerIntPair<Type *, 3, SSAValueKind> typeAndKind; const llvm::PointerIntPair<Type, 3, SSAValueKind> typeAndKind;
}; };
inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) { inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) {
@ -127,7 +127,7 @@ public:
inline use_range getUses() const; inline use_range getUses() const;
protected: protected:
SSAValueImpl(KindTy kind, Type *type) : SSAValue((SSAValueKind)kind, type) {} SSAValueImpl(KindTy kind, Type type) : SSAValue((SSAValueKind)kind, type) {}
}; };
// Utility functions for iterating through SSAValue uses. // Utility functions for iterating through SSAValue uses.

View File

@ -44,7 +44,7 @@ public:
/// Create a new OperationStmt with the specific fields. /// Create a new OperationStmt with the specific fields.
static OperationStmt *create(Location *location, OperationName name, static OperationStmt *create(Location *location, OperationName name,
ArrayRef<MLValue *> operands, ArrayRef<MLValue *> operands,
ArrayRef<Type *> resultTypes, ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes, ArrayRef<NamedAttribute> attributes,
MLIRContext *context); MLIRContext *context);
@ -329,7 +329,7 @@ public:
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
/// Return the context this operation is associated with. /// Return the context this operation is associated with.
MLIRContext *getContext() const { return getType()->getContext(); } MLIRContext *getContext() const { return getType().getContext(); }
using Statement::dump; using Statement::dump;
using Statement::print; using Statement::print;

View File

@ -20,6 +20,7 @@
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
namespace mlir { namespace mlir {
class AffineMap; class AffineMap;
@ -28,6 +29,22 @@ class IntegerType;
class FloatType; class FloatType;
class OtherType; class OtherType;
namespace detail {
class TypeStorage;
class IntegerTypeStorage;
class FloatTypeStorage;
struct OtherTypeStorage;
struct FunctionTypeStorage;
struct VectorOrTensorTypeStorage;
struct VectorTypeStorage;
struct TensorTypeStorage;
struct RankedTensorTypeStorage;
struct UnrankedTensorTypeStorage;
struct MemRefTypeStorage;
} // namespace detail
/// Instances of the Type class are immutable, uniqued, immortal, and owned by /// Instances of the Type class are immutable, uniqued, immortal, and owned by
/// MLIRContext. As such, they are passed around by raw non-const pointer. /// MLIRContext. As such, they are passed around by raw non-const pointer.
/// ///
@ -68,11 +85,34 @@ public:
MemRef, MemRef,
}; };
using ImplType = detail::TypeStorage;
Type() : type(nullptr) {}
/* implicit */ Type(const ImplType *type)
: type(const_cast<ImplType *>(type)) {}
Type(const Type &other) : type(other.type) {}
Type &operator=(Type other) {
type = other.type;
return *this;
}
bool operator==(Type other) const { return type == other.type; }
bool operator!=(Type other) const { return !(*this == other); }
explicit operator bool() const { return type; }
bool operator!() const { return type == nullptr; }
template <typename U> bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
/// Return the classification for this type. /// Return the classification for this type.
Kind getKind() const { return kind; } Kind getKind() const;
/// Return the LLVMContext in which this type was uniqued. /// Return the LLVMContext in which this type was uniqued.
MLIRContext *getContext() const { return context; } MLIRContext *getContext() const;
// Convenience predicates. This is only for 'other' and floating point types, // Convenience predicates. This is only for 'other' and floating point types,
// derived types should use isa/dyn_cast. // derived types should use isa/dyn_cast.
@ -97,56 +137,42 @@ public:
unsigned getBitWidth() const; unsigned getBitWidth() const;
// Convenience factories. // Convenience factories.
static IntegerType *getInteger(unsigned width, MLIRContext *ctx); static IntegerType getInteger(unsigned width, MLIRContext *ctx);
static FloatType *getBF16(MLIRContext *ctx); static FloatType getBF16(MLIRContext *ctx);
static FloatType *getF16(MLIRContext *ctx); static FloatType getF16(MLIRContext *ctx);
static FloatType *getF32(MLIRContext *ctx); static FloatType getF32(MLIRContext *ctx);
static FloatType *getF64(MLIRContext *ctx); static FloatType getF64(MLIRContext *ctx);
static OtherType *getIndex(MLIRContext *ctx); static OtherType getIndex(MLIRContext *ctx);
static OtherType *getTFControl(MLIRContext *ctx); static OtherType getTFControl(MLIRContext *ctx);
static OtherType *getTFString(MLIRContext *ctx); static OtherType getTFString(MLIRContext *ctx);
static OtherType *getTFResource(MLIRContext *ctx); static OtherType getTFResource(MLIRContext *ctx);
static OtherType *getTFVariant(MLIRContext *ctx); static OtherType getTFVariant(MLIRContext *ctx);
static OtherType *getTFComplex64(MLIRContext *ctx); static OtherType getTFComplex64(MLIRContext *ctx);
static OtherType *getTFComplex128(MLIRContext *ctx); static OtherType getTFComplex128(MLIRContext *ctx);
static OtherType *getTFF32REF(MLIRContext *ctx); static OtherType getTFF32REF(MLIRContext *ctx);
/// Print the current type. /// Print the current type.
void print(raw_ostream &os) const; void print(raw_ostream &os) const;
void dump() const; void dump() const;
friend ::llvm::hash_code hash_value(Type arg);
unsigned getSubclassData() const;
void setSubclassData(unsigned val);
/// Methods for supporting PointerLikeTypeTraits.
const void *getAsOpaquePointer() const {
return static_cast<const void *>(type);
}
static Type getFromOpaquePointer(const void *pointer) {
return Type((ImplType *)(pointer));
}
protected: protected:
explicit Type(Kind kind, MLIRContext *context) ImplType *type;
: context(context), kind(kind), subclassData(0) {}
explicit Type(Kind kind, MLIRContext *context, unsigned subClassData)
: Type(kind, context) {
setSubclassData(subClassData);
}
~Type() {}
unsigned getSubclassData() const { return subclassData; }
void setSubclassData(unsigned val) {
subclassData = val;
// Ensure we don't have any accidental truncation.
assert(getSubclassData() == val && "Subclass data too large for field");
}
private:
Type(const Type&) = delete;
void operator=(const Type&) = delete;
/// This refers to the MLIRContext in which this type was uniqued.
MLIRContext *const context;
/// Classification of the subclass, used for type checking.
Kind kind : 8;
// Space for subclasses to store data.
unsigned subclassData : 24;
}; };
inline raw_ostream &operator<<(raw_ostream &os, const Type &type) { inline raw_ostream &operator<<(raw_ostream &os, Type type) {
type.print(os); type.print(os);
return os; return os;
} }
@ -154,148 +180,138 @@ inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
/// Integer types can have arbitrary bitwidth up to a large fixed limit. /// Integer types can have arbitrary bitwidth up to a large fixed limit.
class IntegerType : public Type { class IntegerType : public Type {
public: public:
static IntegerType *get(unsigned width, MLIRContext *context); using ImplType = detail::IntegerTypeStorage;
IntegerType() = default;
/* implicit */ IntegerType(Type::ImplType *ptr);
static IntegerType get(unsigned width, MLIRContext *context);
/// Return the bitwidth of this integer type. /// Return the bitwidth of this integer type.
unsigned getWidth() const { unsigned getWidth() const;
return width;
}
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) { static bool kindof(Kind kind) { return kind == Kind::Integer; }
return type->getKind() == Kind::Integer;
}
/// Integer representation maximal bitwidth. /// Integer representation maximal bitwidth.
static constexpr unsigned kMaxWidth = 4096; static constexpr unsigned kMaxWidth = 4096;
private:
unsigned width;
IntegerType(unsigned width, MLIRContext *context);
~IntegerType() = delete;
}; };
inline IntegerType *Type::getInteger(unsigned width, MLIRContext *ctx) { inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
return IntegerType::get(width, ctx); return IntegerType::get(width, ctx);
} }
/// Return true if this is an integer type with the specified width. /// Return true if this is an integer type with the specified width.
inline bool Type::isInteger(unsigned width) const { inline bool Type::isInteger(unsigned width) const {
if (auto *intTy = dyn_cast<IntegerType>(this)) if (auto intTy = dyn_cast<IntegerType>())
return intTy->getWidth() == width; return intTy.getWidth() == width;
return false; return false;
} }
class FloatType : public Type { class FloatType : public Type {
public: public:
using ImplType = detail::FloatTypeStorage;
FloatType() = default;
/* implicit */ FloatType(Type::ImplType *ptr);
static FloatType get(Kind kind, MLIRContext *context);
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) { static bool kindof(Kind kind) {
return type->getKind() >= Kind::FIRST_FLOATING_POINT_TYPE && return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
type->getKind() <= Kind::LAST_FLOATING_POINT_TYPE; kind <= Kind::LAST_FLOATING_POINT_TYPE;
} }
static FloatType *get(Kind kind, MLIRContext *context);
private:
FloatType(Kind kind, MLIRContext *context);
~FloatType() = delete;
}; };
inline FloatType *Type::getBF16(MLIRContext *ctx) { inline FloatType Type::getBF16(MLIRContext *ctx) {
return FloatType::get(Kind::BF16, ctx); return FloatType::get(Kind::BF16, ctx);
} }
inline FloatType *Type::getF16(MLIRContext *ctx) { inline FloatType Type::getF16(MLIRContext *ctx) {
return FloatType::get(Kind::F16, ctx); return FloatType::get(Kind::F16, ctx);
} }
inline FloatType *Type::getF32(MLIRContext *ctx) { inline FloatType Type::getF32(MLIRContext *ctx) {
return FloatType::get(Kind::F32, ctx); return FloatType::get(Kind::F32, ctx);
} }
inline FloatType *Type::getF64(MLIRContext *ctx) { inline FloatType Type::getF64(MLIRContext *ctx) {
return FloatType::get(Kind::F64, ctx); return FloatType::get(Kind::F64, ctx);
} }
/// This is a type for the random collection of special base types. /// This is a type for the random collection of special base types.
class OtherType : public Type { class OtherType : public Type {
public: public:
/// Methods for support type inquiry through isa, cast, and dyn_cast. using ImplType = detail::OtherTypeStorage;
static bool classof(const Type *type) { OtherType() = default;
return type->getKind() >= Kind::FIRST_OTHER_TYPE && /* implicit */ OtherType(Type::ImplType *ptr);
type->getKind() <= Kind::LAST_OTHER_TYPE;
}
static OtherType *get(Kind kind, MLIRContext *context);
private: static OtherType get(Kind kind, MLIRContext *context);
OtherType(Kind kind, MLIRContext *context);
~OtherType() = delete; /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) {
return kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE;
}
}; };
inline OtherType *Type::getIndex(MLIRContext *ctx) { inline OtherType Type::getIndex(MLIRContext *ctx) {
return OtherType::get(Kind::Index, ctx); return OtherType::get(Kind::Index, ctx);
} }
inline OtherType *Type::getTFControl(MLIRContext *ctx) { inline OtherType Type::getTFControl(MLIRContext *ctx) {
return OtherType::get(Kind::TFControl, ctx); return OtherType::get(Kind::TFControl, ctx);
} }
inline OtherType *Type::getTFResource(MLIRContext *ctx) { inline OtherType Type::getTFResource(MLIRContext *ctx) {
return OtherType::get(Kind::TFResource, ctx); return OtherType::get(Kind::TFResource, ctx);
} }
inline OtherType *Type::getTFString(MLIRContext *ctx) { inline OtherType Type::getTFString(MLIRContext *ctx) {
return OtherType::get(Kind::TFString, ctx); return OtherType::get(Kind::TFString, ctx);
} }
inline OtherType *Type::getTFVariant(MLIRContext *ctx) { inline OtherType Type::getTFVariant(MLIRContext *ctx) {
return OtherType::get(Kind::TFVariant, ctx); return OtherType::get(Kind::TFVariant, ctx);
} }
inline OtherType *Type::getTFComplex64(MLIRContext *ctx) { inline OtherType Type::getTFComplex64(MLIRContext *ctx) {
return OtherType::get(Kind::TFComplex64, ctx); return OtherType::get(Kind::TFComplex64, ctx);
} }
inline OtherType *Type::getTFComplex128(MLIRContext *ctx) { inline OtherType Type::getTFComplex128(MLIRContext *ctx) {
return OtherType::get(Kind::TFComplex128, ctx); return OtherType::get(Kind::TFComplex128, ctx);
} }
inline OtherType *Type::getTFF32REF(MLIRContext *ctx) { inline OtherType Type::getTFF32REF(MLIRContext *ctx) {
return OtherType::get(Kind::TFF32REF, ctx); return OtherType::get(Kind::TFF32REF, ctx);
} }
/// Function types map from a list of inputs to a list of results. /// Function types map from a list of inputs to a list of results.
class FunctionType : public Type { class FunctionType : public Type {
public: public:
static FunctionType *get(ArrayRef<Type*> inputs, ArrayRef<Type*> results, using ImplType = detail::FunctionTypeStorage;
MLIRContext *context); FunctionType() = default;
/* implicit */ FunctionType(Type::ImplType *ptr);
static FunctionType get(ArrayRef<Type> inputs, ArrayRef<Type> results,
MLIRContext *context);
// Input types. // Input types.
unsigned getNumInputs() const { return getSubclassData(); } unsigned getNumInputs() const { return getSubclassData(); }
Type *getInput(unsigned i) const { return getInputs()[i]; } Type getInput(unsigned i) const { return getInputs()[i]; }
ArrayRef<Type*> getInputs() const { ArrayRef<Type> getInputs() const;
return ArrayRef<Type *>(inputsAndResults, getNumInputs());
}
// Result types. // Result types.
unsigned getNumResults() const { return numResults; } unsigned getNumResults() const;
Type *getResult(unsigned i) const { return getResults()[i]; } Type getResult(unsigned i) const { return getResults()[i]; }
ArrayRef<Type*> getResults() const { ArrayRef<Type> getResults() const;
return ArrayRef<Type *>(inputsAndResults + getSubclassData(), numResults);
}
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) { static bool kindof(Kind kind) { return kind == Kind::Function; }
return type->getKind() == Kind::Function;
}
private:
unsigned numResults;
Type *const *inputsAndResults;
FunctionType(Type *const *inputsAndResults, unsigned numInputs,
unsigned numResults, MLIRContext *context);
~FunctionType() = delete;
}; };
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor /// This is a common base class between Vector, UnrankedTensor, and RankedTensor
/// types, because many operations work on values of these aggregate types. /// types, because many operations work on values of these aggregate types.
class VectorOrTensorType : public Type { class VectorOrTensorType : public Type {
public: public:
Type *getElementType() const { return elementType; } using ImplType = detail::VectorOrTensorTypeStorage;
VectorOrTensorType() = default;
/* implicit */ VectorOrTensorType(Type::ImplType *ptr);
Type getElementType() const;
/// If this is ranked tensor or vector type, return the number of elements. If /// If this is ranked tensor or vector type, return the number of elements. If
/// it is an unranked tensor or vector, abort. /// it is an unranked tensor or vector, abort.
@ -319,56 +335,40 @@ public:
int getDimSize(unsigned i) const; int getDimSize(unsigned i) const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) { static bool kindof(Kind kind) {
return type->getKind() == Kind::Vector || return kind == Kind::Vector || kind == Kind::RankedTensor ||
type->getKind() == Kind::RankedTensor || kind == Kind::UnrankedTensor;
type->getKind() == Kind::UnrankedTensor;
} }
public:
Type *elementType;
VectorOrTensorType(Kind kind, MLIRContext *context, Type *elementType,
unsigned subClassData = 0);
}; };
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
/// known constant shape with one or more dimension. /// known constant shape with one or more dimension.
class VectorType : public VectorOrTensorType { class VectorType : public VectorOrTensorType {
public: public:
static VectorType *get(ArrayRef<int> shape, Type *elementType); using ImplType = detail::VectorTypeStorage;
VectorType() = default;
/* implicit */ VectorType(Type::ImplType *ptr);
ArrayRef<int> getShape() const { static VectorType get(ArrayRef<int> shape, Type elementType);
return ArrayRef<int>(shapeElements, getSubclassData());
} ArrayRef<int> getShape() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) { static bool kindof(Kind kind) { return kind == Kind::Vector; }
return type->getKind() == Kind::Vector;
}
private:
const int *shapeElements;
Type *elementType;
VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context);
~VectorType() = delete;
}; };
/// Tensor types represent multi-dimensional arrays, and have two variants: /// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType. /// RankedTensorType and UnrankedTensorType.
class TensorType : public VectorOrTensorType { class TensorType : public VectorOrTensorType {
public: public:
using ImplType = detail::TensorTypeStorage;
TensorType() = default;
/* implicit */ TensorType(Type::ImplType *ptr);
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) { static bool kindof(Kind kind) {
return type->getKind() == Kind::RankedTensor || return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor;
type->getKind() == Kind::UnrankedTensor;
} }
protected:
TensorType(Kind kind, Type *elementType, MLIRContext *context);
~TensorType() {}
}; };
/// Ranked tensor types represent multi-dimensional arrays that have a shape /// Ranked tensor types represent multi-dimensional arrays that have a shape
@ -376,40 +376,30 @@ protected:
/// integer or unknown (represented -1). /// integer or unknown (represented -1).
class RankedTensorType : public TensorType { class RankedTensorType : public TensorType {
public: public:
static RankedTensorType *get(ArrayRef<int> shape, using ImplType = detail::RankedTensorTypeStorage;
Type *elementType); RankedTensorType() = default;
/* implicit */ RankedTensorType(Type::ImplType *ptr);
ArrayRef<int> getShape() const { static RankedTensorType get(ArrayRef<int> shape, Type elementType);
return ArrayRef<int>(shapeElements, getSubclassData());
}
static bool classof(const Type *type) { ArrayRef<int> getShape() const;
return type->getKind() == Kind::RankedTensor;
}
private: static bool kindof(Kind kind) { return kind == Kind::RankedTensor; }
const int *shapeElements;
RankedTensorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context);
~RankedTensorType() = delete;
}; };
/// Unranked tensor types represent multi-dimensional arrays that have an /// Unranked tensor types represent multi-dimensional arrays that have an
/// unknown shape. /// unknown shape.
class UnrankedTensorType : public TensorType { class UnrankedTensorType : public TensorType {
public: public:
static UnrankedTensorType *get(Type *elementType); using ImplType = detail::UnrankedTensorTypeStorage;
UnrankedTensorType() = default;
/* implicit */ UnrankedTensorType(Type::ImplType *ptr);
static UnrankedTensorType get(Type elementType);
ArrayRef<int> getShape() const { return ArrayRef<int>(); } ArrayRef<int> getShape() const { return ArrayRef<int>(); }
static bool classof(const Type *type) { static bool kindof(Kind kind) { return kind == Kind::UnrankedTensor; }
return type->getKind() == Kind::UnrankedTensor;
}
private:
UnrankedTensorType(Type *elementType, MLIRContext *context);
~UnrankedTensorType() = delete;
}; };
/// MemRef types represent a region of memory that have a shape with a fixed /// MemRef types represent a region of memory that have a shape with a fixed
@ -418,62 +408,96 @@ private:
/// affine map composition, represented as an array AffineMap pointers. /// affine map composition, represented as an array AffineMap pointers.
class MemRefType : public Type { class MemRefType : public Type {
public: public:
using ImplType = detail::MemRefTypeStorage;
MemRefType() = default;
/* implicit */ MemRefType(Type::ImplType *ptr);
/// Get or create a new MemRefType based on shape, element type, affine /// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space. /// map composition, and memory space.
static MemRefType *get(ArrayRef<int> shape, Type *elementType, static MemRefType get(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition, ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace); unsigned memorySpace);
unsigned getRank() const { return getShape().size(); } unsigned getRank() const { return getShape().size(); }
/// Returns an array of memref shape dimension sizes. /// Returns an array of memref shape dimension sizes.
ArrayRef<int> getShape() const { ArrayRef<int> getShape() const;
return ArrayRef<int>(shapeElements, getSubclassData());
}
/// Return the size of the specified dimension, or -1 if unspecified. /// Return the size of the specified dimension, or -1 if unspecified.
int getDimSize(unsigned i) const { return getShape()[i]; } int getDimSize(unsigned i) const { return getShape()[i]; }
/// Returns the elemental type for this memref shape. /// Returns the elemental type for this memref shape.
Type *getElementType() const { return elementType; } Type getElementType() const;
/// Returns an array of affine map pointers representing the memref affine /// Returns an array of affine map pointers representing the memref affine
/// map composition. /// map composition.
ArrayRef<AffineMap> getAffineMaps() const; ArrayRef<AffineMap> getAffineMaps() const;
/// Returns the memory space in which data referred to by this memref resides. /// Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpace() const { return memorySpace; } unsigned getMemorySpace() const;
/// Returns the number of dimensions with dynamic size. /// Returns the number of dimensions with dynamic size.
unsigned getNumDynamicDims() const; unsigned getNumDynamicDims() const;
static bool classof(const Type *type) { static bool kindof(Kind kind) { return kind == Kind::MemRef; }
return type->getKind() == Kind::MemRef;
}
private:
/// The type of each scalar element of the memref.
Type *elementType;
/// An array of integers which stores the shape dimension sizes.
const int *shapeElements;
/// The number of affine maps in the 'affineMapList' array.
const unsigned numAffineMaps;
/// List of affine maps in the memref's layout/index map composition.
AffineMap const *affineMapList;
/// Memory space in which data referenced by memref resides.
const unsigned memorySpace;
MemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
MLIRContext *context);
~MemRefType() = delete;
}; };
/// Return true if the specified element type is ok in a tensor. // Make Type hashable.
static bool isValidTensorElementType(Type *type) { inline ::llvm::hash_code hash_value(Type arg) {
return isa<FloatType>(type) || isa<VectorType>(type) || return ::llvm::hash_value(arg.type);
isa<IntegerType>(type) || isa<OtherType>(type);
} }
template <typename U> bool Type::isa() const {
assert(type && "isa<> used on a null type.");
return U::kindof(getKind());
}
template <typename U> U Type::dyn_cast() const {
return isa<U>() ? U(type) : U(nullptr);
}
template <typename U> U Type::dyn_cast_or_null() const {
return (type && isa<U>()) ? U(type) : U(nullptr);
}
template <typename U> U Type::cast() const {
assert(isa<U>());
return U(type);
}
/// Return true if the specified element type is ok in a tensor.
static bool isValidTensorElementType(Type type) {
return type.isa<FloatType>() || type.isa<VectorType>() ||
type.isa<IntegerType>() || type.isa<OtherType>();
}
} // end namespace mlir } // end namespace mlir
namespace llvm {
// Type hash just like pointers.
template <> struct DenseMapInfo<mlir::Type> {
static mlir::Type getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
}
static mlir::Type getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
};
/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
template <> struct PointerLikeTypeTraits<mlir::Type> {
public:
static inline void *getAsVoidPointer(mlir::Type I) {
return const_cast<void *>(I.getAsOpaquePointer());
}
static inline mlir::Type getFromVoidPointer(void *P) {
return mlir::Type::getFromOpaquePointer(P);
}
enum { NumLowBitsAvailable = 3 };
};
} // namespace llvm
#endif // MLIR_IR_TYPES_H #endif // MLIR_IR_TYPES_H

View File

@ -104,15 +104,15 @@ class AllocOp
: public Op<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> { : public Op<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
public: public:
/// The result of an alloc is always a MemRefType. /// The result of an alloc is always a MemRefType.
MemRefType *getType() const { MemRefType getType() const {
return cast<MemRefType>(getResult()->getType()); return getResult()->getType().cast<MemRefType>();
} }
static StringRef getOperationName() { return "alloc"; } static StringRef getOperationName() { return "alloc"; }
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result,
MemRefType *memrefType, ArrayRef<SSAValue *> operands = {}); MemRefType memrefType, ArrayRef<SSAValue *> operands = {});
bool verify() const; bool verify() const;
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const; void print(OpAsmPrinter *p) const;
@ -276,7 +276,7 @@ public:
const SSAValue *getSrcMemRef() const { return getOperand(0); } const SSAValue *getSrcMemRef() const { return getOperand(0); }
// Returns the rank (number of indices) of the source MemRefType. // Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() const { unsigned getSrcMemRefRank() const {
return cast<MemRefType>(getSrcMemRef()->getType())->getRank(); return getSrcMemRef()->getType().cast<MemRefType>().getRank();
} }
// Returns the source memerf indices for this DMA operation. // Returns the source memerf indices for this DMA operation.
llvm::iterator_range<Operation::const_operand_iterator> llvm::iterator_range<Operation::const_operand_iterator>
@ -291,13 +291,13 @@ public:
} }
// Returns the rank (number of indices) of the destination MemRefType. // Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() const { unsigned getDstMemRefRank() const {
return cast<MemRefType>(getDstMemRef()->getType())->getRank(); return getDstMemRef()->getType().cast<MemRefType>().getRank();
} }
unsigned getSrcMemorySpace() const { unsigned getSrcMemorySpace() const {
return cast<MemRefType>(getSrcMemRef()->getType())->getMemorySpace(); return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
} }
unsigned getDstMemorySpace() const { unsigned getDstMemorySpace() const {
return cast<MemRefType>(getDstMemRef()->getType())->getMemorySpace(); return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
} }
// Returns the destination memref indices for this DMA operation. // Returns the destination memref indices for this DMA operation.
@ -387,7 +387,7 @@ public:
// Returns the rank (number of indices) of the tag memref. // Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() const { unsigned getTagMemRefRank() const {
return cast<MemRefType>(getTagMemRef()->getType())->getRank(); return getTagMemRef()->getType().cast<MemRefType>().getRank();
} }
// Returns the number of elements transferred in the associated DMA operation. // Returns the number of elements transferred in the associated DMA operation.
@ -460,8 +460,8 @@ public:
SSAValue *getMemRef() { return getOperand(0); } SSAValue *getMemRef() { return getOperand(0); }
const SSAValue *getMemRef() const { return getOperand(0); } const SSAValue *getMemRef() const { return getOperand(0); }
void setMemRef(SSAValue *value) { setOperand(0, value); } void setMemRef(SSAValue *value) { setOperand(0, value); }
MemRefType *getMemRefType() const { MemRefType getMemRefType() const {
return cast<MemRefType>(getMemRef()->getType()); return getMemRef()->getType().cast<MemRefType>();
} }
llvm::iterator_range<Operation::operand_iterator> getIndices() { llvm::iterator_range<Operation::operand_iterator> getIndices() {
@ -508,8 +508,8 @@ public:
static StringRef getOperationName() { return "memref_cast"; } static StringRef getOperationName() { return "memref_cast"; }
/// The result of a memref_cast is always a memref. /// The result of a memref_cast is always a memref.
MemRefType *getType() const { MemRefType getType() const {
return cast<MemRefType>(getResult()->getType()); return getResult()->getType().cast<MemRefType>();
} }
bool verify() const; bool verify() const;
@ -583,8 +583,8 @@ public:
SSAValue *getMemRef() { return getOperand(1); } SSAValue *getMemRef() { return getOperand(1); }
const SSAValue *getMemRef() const { return getOperand(1); } const SSAValue *getMemRef() const { return getOperand(1); }
void setMemRef(SSAValue *value) { setOperand(1, value); } void setMemRef(SSAValue *value) { setOperand(1, value); }
MemRefType *getMemRefType() const { MemRefType getMemRefType() const {
return cast<MemRefType>(getMemRef()->getType()); return getMemRef()->getType().cast<MemRefType>();
} }
llvm::iterator_range<Operation::operand_iterator> getIndices() { llvm::iterator_range<Operation::operand_iterator> getIndices() {
@ -671,8 +671,8 @@ public:
static StringRef getOperationName() { return "tensor_cast"; } static StringRef getOperationName() { return "tensor_cast"; }
/// The result of a tensor_cast is always a tensor. /// The result of a tensor_cast is always a tensor.
TensorType *getType() const { TensorType getType() const {
return cast<TensorType>(getResult()->getType()); return getResult()->getType().cast<TensorType>();
} }
bool verify() const; bool verify() const;

View File

@ -118,15 +118,15 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
return tripCountExpr.getLargestKnownDivisor(); return tripCountExpr.getLargestKnownDivisor();
} }
bool mlir::isAccessInvariant(const MLValue &input, MemRefType *memRefType, bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
ArrayRef<MLValue *> indices, unsigned dim) { ArrayRef<MLValue *> indices, unsigned dim) {
assert(indices.size() == memRefType->getRank()); assert(indices.size() == memRefType.getRank());
assert(dim < indices.size()); assert(dim < indices.size());
auto layoutMap = memRefType->getAffineMaps(); auto layoutMap = memRefType.getAffineMaps();
assert(memRefType->getAffineMaps().size() <= 1); assert(memRefType.getAffineMaps().size() <= 1);
// TODO(ntv): remove dependency on Builder once we support non-identity // TODO(ntv): remove dependency on Builder once we support non-identity
// layout map. // layout map.
Builder b(memRefType->getContext()); Builder b(memRefType.getContext());
assert(layoutMap.empty() || assert(layoutMap.empty() ||
layoutMap[0] == b.getMultiDimIdentityMap(indices.size())); layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
(void)layoutMap; (void)layoutMap;
@ -170,7 +170,7 @@ static bool isContiguousAccess(const MLValue &input,
using namespace functional; using namespace functional;
auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); }, auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
memoryOp->getIndices()); memoryOp->getIndices());
auto *memRefType = memoryOp->getMemRefType(); auto memRefType = memoryOp->getMemRefType();
for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) { for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
if (fastestVaryingDim == (numIndices - 1) - d) { if (fastestVaryingDim == (numIndices - 1) - d) {
continue; continue;
@ -184,8 +184,8 @@ static bool isContiguousAccess(const MLValue &input,
template <typename LoadOrStoreOpPointer> template <typename LoadOrStoreOpPointer>
static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
auto *memRefType = memoryOp->getMemRefType(); auto memRefType = memoryOp->getMemRefType();
return isa<VectorType>(memRefType->getElementType()); return memRefType.getElementType().template isa<VectorType>();
} }
bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) { bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {

View File

@ -195,7 +195,7 @@ bool CFGFuncVerifier::verify() {
// Verify that the argument list of the function and the arg list of the first // Verify that the argument list of the function and the arg list of the first
// block line up. // block line up.
auto fnInputTypes = fn.getType()->getInputs(); auto fnInputTypes = fn.getType().getInputs();
if (fnInputTypes.size() != firstBB->getNumArguments()) if (fnInputTypes.size() != firstBB->getNumArguments())
return failure("first block of cfgfunc must have " + return failure("first block of cfgfunc must have " +
Twine(fnInputTypes.size()) + Twine(fnInputTypes.size()) +
@ -306,7 +306,7 @@ bool CFGFuncVerifier::verifyBBArguments(ArrayRef<InstOperand> operands,
bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) { bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
// Verify that the return operands match the results of the function. // Verify that the return operands match the results of the function.
auto results = fn.getType()->getResults(); auto results = fn.getType().getResults();
if (inst.getNumOperands() != results.size()) if (inst.getNumOperands() != results.size())
return failure("return has " + Twine(inst.getNumOperands()) + return failure("return has " + Twine(inst.getNumOperands()) +
" operands, but enclosing function returns " + " operands, but enclosing function returns " +

View File

@ -122,7 +122,7 @@ private:
void visitForStmt(const ForStmt *forStmt); void visitForStmt(const ForStmt *forStmt);
void visitIfStmt(const IfStmt *ifStmt); void visitIfStmt(const IfStmt *ifStmt);
void visitOperationStmt(const OperationStmt *opStmt); void visitOperationStmt(const OperationStmt *opStmt);
void visitType(const Type *type); void visitType(Type type);
void visitAttribute(Attribute attr); void visitAttribute(Attribute attr);
void visitOperation(const Operation *op); void visitOperation(const Operation *op);
@ -135,16 +135,16 @@ private:
} // end anonymous namespace } // end anonymous namespace
// TODO Support visiting other types/instructions when implemented. // TODO Support visiting other types/instructions when implemented.
void ModuleState::visitType(const Type *type) { void ModuleState::visitType(Type type) {
if (auto *funcType = dyn_cast<FunctionType>(type)) { if (auto funcType = type.dyn_cast<FunctionType>()) {
// Visit input and result types for functions. // Visit input and result types for functions.
for (auto *input : funcType->getInputs()) for (auto input : funcType.getInputs())
visitType(input); visitType(input);
for (auto *result : funcType->getResults()) for (auto result : funcType.getResults())
visitType(result); visitType(result);
} else if (auto *memref = dyn_cast<MemRefType>(type)) { } else if (auto memref = type.dyn_cast<MemRefType>()) {
// Visit affine maps in memref type. // Visit affine maps in memref type.
for (auto map : memref->getAffineMaps()) { for (auto map : memref.getAffineMaps()) {
recordAffineMapReference(map); recordAffineMapReference(map);
} }
} }
@ -271,7 +271,7 @@ public:
void print(const Module *module); void print(const Module *module);
void printFunctionReference(const Function *func); void printFunctionReference(const Function *func);
void printAttribute(Attribute attr); void printAttribute(Attribute attr);
void printType(const Type *type); void printType(Type type);
void print(const Function *fn); void print(const Function *fn);
void print(const ExtFunction *fn); void print(const ExtFunction *fn);
void print(const CFGFunction *fn); void print(const CFGFunction *fn);
@ -290,7 +290,7 @@ protected:
void printFunctionAttributes(const Function *fn); void printFunctionAttributes(const Function *fn);
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {}); ArrayRef<const char *> elidedAttrs = {});
void printFunctionResultType(const FunctionType *type); void printFunctionResultType(FunctionType type);
void printAffineMapId(int affineMapId) const; void printAffineMapId(int affineMapId) const;
void printAffineMapReference(AffineMap affineMap); void printAffineMapReference(AffineMap affineMap);
void printIntegerSetId(int integerSetId) const; void printIntegerSetId(int integerSetId) const;
@ -489,9 +489,9 @@ void ModulePrinter::printAttribute(Attribute attr) {
} }
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
auto *type = attr.getType(); auto type = attr.getType();
auto shape = type->getShape(); auto shape = type.getShape();
auto rank = type->getRank(); auto rank = type.getRank();
SmallVector<Attribute, 16> elements; SmallVector<Attribute, 16> elements;
attr.getValues(elements); attr.getValues(elements);
@ -541,8 +541,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
os << ']'; os << ']';
} }
void ModulePrinter::printType(const Type *type) { void ModulePrinter::printType(Type type) {
switch (type->getKind()) { switch (type.getKind()) {
case Type::Kind::Index: case Type::Kind::Index:
os << "index"; os << "index";
return; return;
@ -581,71 +581,71 @@ void ModulePrinter::printType(const Type *type) {
return; return;
case Type::Kind::Integer: { case Type::Kind::Integer: {
auto *integer = cast<IntegerType>(type); auto integer = type.cast<IntegerType>();
os << 'i' << integer->getWidth(); os << 'i' << integer.getWidth();
return; return;
} }
case Type::Kind::Function: { case Type::Kind::Function: {
auto *func = cast<FunctionType>(type); auto func = type.cast<FunctionType>();
os << '('; os << '(';
interleaveComma(func->getInputs(), [&](Type *type) { printType(type); }); interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
os << ") -> "; os << ") -> ";
auto results = func->getResults(); auto results = func.getResults();
if (results.size() == 1) if (results.size() == 1)
os << *results[0]; os << results[0];
else { else {
os << '('; os << '(';
interleaveComma(results, [&](Type *type) { printType(type); }); interleaveComma(results, [&](Type type) { printType(type); });
os << ')'; os << ')';
} }
return; return;
} }
case Type::Kind::Vector: { case Type::Kind::Vector: {
auto *v = cast<VectorType>(type); auto v = type.cast<VectorType>();
os << "vector<"; os << "vector<";
for (auto dim : v->getShape()) for (auto dim : v.getShape())
os << dim << 'x'; os << dim << 'x';
os << *v->getElementType() << '>'; os << v.getElementType() << '>';
return; return;
} }
case Type::Kind::RankedTensor: { case Type::Kind::RankedTensor: {
auto *v = cast<RankedTensorType>(type); auto v = type.cast<RankedTensorType>();
os << "tensor<"; os << "tensor<";
for (auto dim : v->getShape()) { for (auto dim : v.getShape()) {
if (dim < 0) if (dim < 0)
os << '?'; os << '?';
else else
os << dim; os << dim;
os << 'x'; os << 'x';
} }
os << *v->getElementType() << '>'; os << v.getElementType() << '>';
return; return;
} }
case Type::Kind::UnrankedTensor: { case Type::Kind::UnrankedTensor: {
auto *v = cast<UnrankedTensorType>(type); auto v = type.cast<UnrankedTensorType>();
os << "tensor<*x"; os << "tensor<*x";
printType(v->getElementType()); printType(v.getElementType());
os << '>'; os << '>';
return; return;
} }
case Type::Kind::MemRef: { case Type::Kind::MemRef: {
auto *v = cast<MemRefType>(type); auto v = type.cast<MemRefType>();
os << "memref<"; os << "memref<";
for (auto dim : v->getShape()) { for (auto dim : v.getShape()) {
if (dim < 0) if (dim < 0)
os << '?'; os << '?';
else else
os << dim; os << dim;
os << 'x'; os << 'x';
} }
printType(v->getElementType()); printType(v.getElementType());
for (auto map : v->getAffineMaps()) { for (auto map : v.getAffineMaps()) {
os << ", "; os << ", ";
printAffineMapReference(map); printAffineMapReference(map);
} }
// Only print the memory space if it is the non-default one. // Only print the memory space if it is the non-default one.
if (v->getMemorySpace()) if (v.getMemorySpace())
os << ", " << v->getMemorySpace(); os << ", " << v.getMemorySpace();
os << '>'; os << '>';
return; return;
} }
@ -842,18 +842,18 @@ void ModulePrinter::printIntegerSet(IntegerSet set) {
// Function printing // Function printing
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void ModulePrinter::printFunctionResultType(const FunctionType *type) { void ModulePrinter::printFunctionResultType(FunctionType type) {
switch (type->getResults().size()) { switch (type.getResults().size()) {
case 0: case 0:
break; break;
case 1: case 1:
os << " -> "; os << " -> ";
printType(type->getResults()[0]); printType(type.getResults()[0]);
break; break;
default: default:
os << " -> ("; os << " -> (";
interleaveComma(type->getResults(), interleaveComma(type.getResults(),
[&](Type *eltType) { printType(eltType); }); [&](Type eltType) { printType(eltType); });
os << ')'; os << ')';
break; break;
} }
@ -871,8 +871,7 @@ void ModulePrinter::printFunctionSignature(const Function *fn) {
auto type = fn->getType(); auto type = fn->getType();
os << "@" << fn->getName() << '('; os << "@" << fn->getName() << '(';
interleaveComma(type->getInputs(), interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
[&](Type *eltType) { printType(eltType); });
os << ')'; os << ')';
printFunctionResultType(type); printFunctionResultType(type);
@ -937,7 +936,7 @@ public:
// Implement OpAsmPrinter. // Implement OpAsmPrinter.
raw_ostream &getStream() const { return os; } raw_ostream &getStream() const { return os; }
void printType(const Type *type) { ModulePrinter::printType(type); } void printType(Type type) { ModulePrinter::printType(type); }
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); } void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
void printAffineMap(AffineMap map) { void printAffineMap(AffineMap map) {
return ModulePrinter::printAffineMapReference(map); return ModulePrinter::printAffineMapReference(map);
@ -974,10 +973,10 @@ protected:
if (auto *op = value->getDefiningOperation()) { if (auto *op = value->getDefiningOperation()) {
if (auto intOp = op->dyn_cast<ConstantIntOp>()) { if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
// i1 constants get special names. // i1 constants get special names.
if (intOp->getType()->isInteger(1)) { if (intOp->getType().isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false"); specialName << (intOp->getValue() ? "true" : "false");
} else { } else {
specialName << 'c' << intOp->getValue() << '_' << *intOp->getType(); specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
} }
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) { } else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue(); specialName << 'c' << intOp->getValue();
@ -1579,7 +1578,7 @@ void Attribute::dump() const { print(llvm::errs()); }
void Type::print(raw_ostream &os) const { void Type::print(raw_ostream &os) const {
ModuleState state(getContext()); ModuleState state(getContext());
ModulePrinter(os, state).printType(this); ModulePrinter(os, state).printType(*this);
} }
void Type::dump() const { print(llvm::errs()); } void Type::dump() const { print(llvm::errs()); }

View File

@ -26,6 +26,7 @@
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/IntegerSet.h" #include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/TrailingObjects.h" #include "llvm/Support/TrailingObjects.h"
namespace mlir { namespace mlir {
@ -86,7 +87,7 @@ struct IntegerSetAttributeStorage : public AttributeStorage {
/// An attribute representing a reference to a type. /// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage { struct TypeAttributeStorage : public AttributeStorage {
Type *value; Type value;
}; };
/// An attribute representing a reference to a function. /// An attribute representing a reference to a function.
@ -96,7 +97,7 @@ struct FunctionAttributeStorage : public AttributeStorage {
/// A base attribute representing a reference to a vector or tensor constant. /// A base attribute representing a reference to a vector or tensor constant.
struct ElementsAttributeStorage : public AttributeStorage { struct ElementsAttributeStorage : public AttributeStorage {
VectorOrTensorType *type; VectorOrTensorType type;
}; };
/// An attribute representing a reference to a vector or tensor constant, /// An attribute representing a reference to a vector or tensor constant,

View File

@ -75,9 +75,7 @@ IntegerSet IntegerSetAttr::getValue() const {
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
Type *TypeAttr::getValue() const { Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
return static_cast<ImplType *>(attr)->value;
}
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
@ -85,11 +83,11 @@ Function *FunctionAttr::getValue() const {
return static_cast<ImplType *>(attr)->value; return static_cast<ImplType *>(attr)->value;
} }
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); } FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
VectorOrTensorType *ElementsAttr::getType() const { VectorOrTensorType ElementsAttr::getType() const {
return static_cast<ImplType *>(attr)->type; return static_cast<ImplType *>(attr)->type;
} }
@ -166,8 +164,8 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const { void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth; auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
auto elementNum = getType()->getNumElements(); auto elementNum = getType().getNumElements();
auto context = getType()->getContext(); auto context = getType().getContext();
values.reserve(elementNum); values.reserve(elementNum);
if (bitsWidth == 64) { if (bitsWidth == 64) {
ArrayRef<int64_t> vs( ArrayRef<int64_t> vs(
@ -192,8 +190,8 @@ DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
: DenseElementsAttr(ptr) {} : DenseElementsAttr(ptr) {}
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const { void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto elementNum = getType()->getNumElements(); auto elementNum = getType().getNumElements();
auto context = getType()->getContext(); auto context = getType().getContext();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()), ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
getRawData().size() / 8}); getRawData().size() / 8});
values.reserve(elementNum); values.reserve(elementNum);

View File

@ -33,18 +33,18 @@ BasicBlock::~BasicBlock() {
// Argument list management. // Argument list management.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
BBArgument *BasicBlock::addArgument(Type *type) { BBArgument *BasicBlock::addArgument(Type type) {
auto *arg = new BBArgument(type, this); auto *arg = new BBArgument(type, this);
arguments.push_back(arg); arguments.push_back(arg);
return arg; return arg;
} }
/// Add one argument to the argument list for each type specified in the list. /// Add one argument to the argument list for each type specified in the list.
auto BasicBlock::addArguments(ArrayRef<Type *> types) auto BasicBlock::addArguments(ArrayRef<Type> types)
-> llvm::iterator_range<args_iterator> { -> llvm::iterator_range<args_iterator> {
arguments.reserve(arguments.size() + types.size()); arguments.reserve(arguments.size() + types.size());
auto initialSize = arguments.size(); auto initialSize = arguments.size();
for (auto *type : types) { for (auto type : types) {
addArgument(type); addArgument(type);
} }
return {arguments.data() + initialSize, arguments.data() + arguments.size()}; return {arguments.data() + initialSize, arguments.data() + arguments.size()};

View File

@ -52,59 +52,58 @@ FileLineColLoc *Builder::getFileLineColLoc(UniquedFilename filename,
// Types. // Types.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
FloatType *Builder::getBF16Type() { return Type::getBF16(context); } FloatType Builder::getBF16Type() { return Type::getBF16(context); }
FloatType *Builder::getF16Type() { return Type::getF16(context); } FloatType Builder::getF16Type() { return Type::getF16(context); }
FloatType *Builder::getF32Type() { return Type::getF32(context); } FloatType Builder::getF32Type() { return Type::getF32(context); }
FloatType *Builder::getF64Type() { return Type::getF64(context); } FloatType Builder::getF64Type() { return Type::getF64(context); }
OtherType *Builder::getIndexType() { return Type::getIndex(context); } OtherType Builder::getIndexType() { return Type::getIndex(context); }
OtherType *Builder::getTFControlType() { return Type::getTFControl(context); } OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); } OtherType Builder::getTFResourceType() { return Type::getTFResource(context); }
OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); } OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); }
OtherType *Builder::getTFComplex64Type() { OtherType Builder::getTFComplex64Type() {
return Type::getTFComplex64(context); return Type::getTFComplex64(context);
} }
OtherType *Builder::getTFComplex128Type() { OtherType Builder::getTFComplex128Type() {
return Type::getTFComplex128(context); return Type::getTFComplex128(context);
} }
OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); } OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
OtherType *Builder::getTFStringType() { return Type::getTFString(context); } OtherType Builder::getTFStringType() { return Type::getTFString(context); }
IntegerType *Builder::getIntegerType(unsigned width) { IntegerType Builder::getIntegerType(unsigned width) {
return Type::getInteger(width, context); return Type::getInteger(width, context);
} }
FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs, FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
ArrayRef<Type *> results) { ArrayRef<Type> results) {
return FunctionType::get(inputs, results, context); return FunctionType::get(inputs, results, context);
} }
MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType, MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition, ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace) { unsigned memorySpace) {
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace); return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
} }
VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) { VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) {
return VectorType::get(shape, elementType); return VectorType::get(shape, elementType);
} }
RankedTensorType *Builder::getTensorType(ArrayRef<int> shape, RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) {
Type *elementType) {
return RankedTensorType::get(shape, elementType); return RankedTensorType::get(shape, elementType);
} }
UnrankedTensorType *Builder::getTensorType(Type *elementType) { UnrankedTensorType Builder::getTensorType(Type elementType) {
return UnrankedTensorType::get(elementType); return UnrankedTensorType::get(elementType);
} }
@ -144,7 +143,7 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
return IntegerSetAttr::get(set); return IntegerSetAttr::get(set);
} }
TypeAttr Builder::getTypeAttr(Type *type) { TypeAttr Builder::getTypeAttr(Type type) {
return TypeAttr::get(type, context); return TypeAttr::get(type, context);
} }
@ -152,23 +151,23 @@ FunctionAttr Builder::getFunctionAttr(const Function *value) {
return FunctionAttr::get(value, context); return FunctionAttr::get(value, context);
} }
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type, ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,
Attribute elt) { Attribute elt) {
return SplatElementsAttr::get(type, elt); return SplatElementsAttr::get(type, elt);
} }
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type, ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
ArrayRef<char> data) { ArrayRef<char> data) {
return DenseElementsAttr::get(type, data); return DenseElementsAttr::get(type, data);
} }
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type, ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices, DenseIntElementsAttr indices,
DenseElementsAttr values) { DenseElementsAttr values) {
return SparseElementsAttr::get(type, indices, values); return SparseElementsAttr::get(type, indices, values);
} }
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type, ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type,
StringRef bytes) { StringRef bytes) {
return OpaqueElementsAttr::get(type, bytes); return OpaqueElementsAttr::get(type, bytes);
} }
@ -296,7 +295,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
OperationStmt *MLFuncBuilder::createOperation(Location *location, OperationStmt *MLFuncBuilder::createOperation(Location *location,
OperationName name, OperationName name,
ArrayRef<MLValue *> operands, ArrayRef<MLValue *> operands,
ArrayRef<Type *> types, ArrayRef<Type> types,
ArrayRef<NamedAttribute> attrs) { ArrayRef<NamedAttribute> attrs) {
auto *op = OperationStmt::create(location, name, operands, types, attrs, auto *op = OperationStmt::create(location, name, operands, types, attrs,
getContext()); getContext());

View File

@ -63,7 +63,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
numDims = opInfos.size(); numDims = opInfos.size();
// Parse the optional symbol operands. // Parse the optional symbol operands.
auto *affineIntTy = parser->getBuilder().getIndexType(); auto affineIntTy = parser->getBuilder().getIndexType();
if (parser->parseOperandList(opInfos, -1, if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimiter::OptionalSquare) || OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands)) parser->resolveOperands(opInfos, affineIntTy, operands))
@ -84,7 +84,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result,
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder(); auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getIndexType(); auto affineIntTy = builder.getIndexType();
AffineMapAttr mapAttr; AffineMapAttr mapAttr;
unsigned numDims; unsigned numDims;
@ -171,7 +171,7 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
/// Builds a constant op with the specified attribute value and result type. /// Builds a constant op with the specified attribute value and result type.
void ConstantOp::build(Builder *builder, OperationState *result, void ConstantOp::build(Builder *builder, OperationState *result,
Attribute value, Type *type) { Attribute value, Type type) {
result->addAttribute("value", value); result->addAttribute("value", value);
result->types.push_back(type); result->types.push_back(type);
} }
@ -181,12 +181,12 @@ void ConstantOp::print(OpAsmPrinter *p) const {
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value"); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
if (!getValue().isa<FunctionAttr>()) if (!getValue().isa<FunctionAttr>())
*p << " : " << *getType(); *p << " : " << getType();
} }
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute valueAttr; Attribute valueAttr;
Type *type; Type type;
if (parser->parseAttribute(valueAttr, "value", result->attributes) || if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes)) parser->parseOptionalAttributeDict(result->attributes))
@ -208,33 +208,33 @@ bool ConstantOp::verify() const {
if (!value) if (!value)
return emitOpError("requires a 'value' attribute"); return emitOpError("requires a 'value' attribute");
auto *type = this->getType(); auto type = this->getType();
if (isa<IntegerType>(type) || type->isIndex()) { if (type.isa<IntegerType>() || type.isIndex()) {
if (!value.isa<IntegerAttr>()) if (!value.isa<IntegerAttr>())
return emitOpError( return emitOpError(
"requires 'value' to be an integer for an integer result type"); "requires 'value' to be an integer for an integer result type");
return false; return false;
} }
if (isa<FloatType>(type)) { if (type.isa<FloatType>()) {
if (!value.isa<FloatAttr>()) if (!value.isa<FloatAttr>())
return emitOpError("requires 'value' to be a floating point constant"); return emitOpError("requires 'value' to be a floating point constant");
return false; return false;
} }
if (isa<VectorOrTensorType>(type)) { if (type.isa<VectorOrTensorType>()) {
if (!value.isa<ElementsAttr>()) if (!value.isa<ElementsAttr>())
return emitOpError("requires 'value' to be a vector/tensor constant"); return emitOpError("requires 'value' to be a vector/tensor constant");
return false; return false;
} }
if (type->isTFString()) { if (type.isTFString()) {
if (!value.isa<StringAttr>()) if (!value.isa<StringAttr>())
return emitOpError("requires 'value' to be a string constant"); return emitOpError("requires 'value' to be a string constant");
return false; return false;
} }
if (isa<FunctionType>(type)) { if (type.isa<FunctionType>()) {
if (!value.isa<FunctionAttr>()) if (!value.isa<FunctionAttr>())
return emitOpError("requires 'value' to be a function reference"); return emitOpError("requires 'value' to be a function reference");
return false; return false;
@ -251,19 +251,19 @@ Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
} }
void ConstantFloatOp::build(Builder *builder, OperationState *result, void ConstantFloatOp::build(Builder *builder, OperationState *result,
const APFloat &value, FloatType *type) { const APFloat &value, FloatType type) {
ConstantOp::build(builder, result, builder->getFloatAttr(value), type); ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
} }
bool ConstantFloatOp::isClassFor(const Operation *op) { bool ConstantFloatOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) && return ConstantOp::isClassFor(op) &&
isa<FloatType>(op->getResult(0)->getType()); op->getResult(0)->getType().isa<FloatType>();
} }
/// ConstantIntOp only matches values whose result type is an IntegerType. /// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::isClassFor(const Operation *op) { bool ConstantIntOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) && return ConstantOp::isClassFor(op) &&
isa<IntegerType>(op->getResult(0)->getType()); op->getResult(0)->getType().isa<IntegerType>();
} }
void ConstantIntOp::build(Builder *builder, OperationState *result, void ConstantIntOp::build(Builder *builder, OperationState *result,
@ -275,14 +275,14 @@ void ConstantIntOp::build(Builder *builder, OperationState *result,
/// Build a constant int op producing an integer with the specified type, /// Build a constant int op producing an integer with the specified type,
/// which must be an integer type. /// which must be an integer type.
void ConstantIntOp::build(Builder *builder, OperationState *result, void ConstantIntOp::build(Builder *builder, OperationState *result,
int64_t value, Type *type) { int64_t value, Type type) {
assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type"); assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
ConstantOp::build(builder, result, builder->getIntegerAttr(value), type); ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
} }
/// ConstantIndexOp only matches values whose result type is Index. /// ConstantIndexOp only matches values whose result type is Index.
bool ConstantIndexOp::isClassFor(const Operation *op) { bool ConstantIndexOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex(); return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
} }
void ConstantIndexOp::build(Builder *builder, OperationState *result, void ConstantIndexOp::build(Builder *builder, OperationState *result,
@ -302,7 +302,7 @@ void ReturnOp::build(Builder *builder, OperationState *result,
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo; SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type *, 2> types; SmallVector<Type, 2> types;
llvm::SMLoc loc; llvm::SMLoc loc;
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) || (!opInfo.empty() && parser->parseColonTypeList(types)) ||
@ -330,7 +330,7 @@ bool ReturnOp::verify() const {
// The operand number and types must match the function signature. // The operand number and types must match the function signature.
MLFunction *function = cast<MLFunction>(block); MLFunction *function = cast<MLFunction>(block);
const auto &results = function->getType()->getResults(); const auto &results = function->getType().getResults();
if (stmt->getNumOperands() != results.size()) if (stmt->getNumOperands() != results.size())
return emitOpError("has " + Twine(stmt->getNumOperands()) + return emitOpError("has " + Twine(stmt->getNumOperands()) +
" operands, but enclosing function returns " + " operands, but enclosing function returns " +

View File

@ -28,8 +28,8 @@
using namespace mlir; using namespace mlir;
Function::Function(Kind kind, Location *location, StringRef name, Function::Function(Kind kind, Location *location, StringRef name,
FunctionType *type, ArrayRef<NamedAttribute> attrs) FunctionType type, ArrayRef<NamedAttribute> attrs)
: nameAndKind(Identifier::get(name, type->getContext()), kind), : nameAndKind(Identifier::get(name, type.getContext()), kind),
location(location), type(type) { location(location), type(type) {
this->attrs = AttributeListStorage::get(attrs, getContext()); this->attrs = AttributeListStorage::get(attrs, getContext());
} }
@ -46,7 +46,7 @@ ArrayRef<NamedAttribute> Function::getAttrs() const {
return {}; return {};
} }
MLIRContext *Function::getContext() const { return getType()->getContext(); } MLIRContext *Function::getContext() const { return getType().getContext(); }
/// Delete this object. /// Delete this object.
void Function::destroy() { void Function::destroy() {
@ -159,7 +159,7 @@ void Function::emitError(const Twine &message) const {
// ExtFunction implementation. // ExtFunction implementation.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type, ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs) ArrayRef<NamedAttribute> attrs)
: Function(Kind::ExtFunc, location, name, type, attrs) {} : Function(Kind::ExtFunc, location, name, type, attrs) {}
@ -167,7 +167,7 @@ ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
// CFGFunction implementation. // CFGFunction implementation.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type, CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs) ArrayRef<NamedAttribute> attrs)
: Function(Kind::CFGFunc, location, name, type, attrs) {} : Function(Kind::CFGFunc, location, name, type, attrs) {}
@ -188,9 +188,9 @@ CFGFunction::~CFGFunction() {
/// Create a new MLFunction with the specific fields. /// Create a new MLFunction with the specific fields.
MLFunction *MLFunction::create(Location *location, StringRef name, MLFunction *MLFunction::create(Location *location, StringRef name,
FunctionType *type, FunctionType type,
ArrayRef<NamedAttribute> attrs) { ArrayRef<NamedAttribute> attrs) {
const auto &argTypes = type->getInputs(); const auto &argTypes = type.getInputs();
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size()); auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
void *rawMem = malloc(byteSize); void *rawMem = malloc(byteSize);
@ -204,7 +204,7 @@ MLFunction *MLFunction::create(Location *location, StringRef name,
return function; return function;
} }
MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type, MLFunction::MLFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs) ArrayRef<NamedAttribute> attrs)
: Function(Kind::MLFunc, location, name, type, attrs), : Function(Kind::MLFunc, location, name, type, attrs),
StmtBlock(StmtBlockKind::MLFunc) {} StmtBlock(StmtBlockKind::MLFunc) {}

View File

@ -143,7 +143,7 @@ void Instruction::emitError(const Twine &message) const {
/// Create a new OperationInst with the specified fields. /// Create a new OperationInst with the specified fields.
OperationInst *OperationInst::create(Location *location, OperationName name, OperationInst *OperationInst::create(Location *location, OperationName name,
ArrayRef<CFGValue *> operands, ArrayRef<CFGValue *> operands,
ArrayRef<Type *> resultTypes, ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes, ArrayRef<NamedAttribute> attributes,
MLIRContext *context) { MLIRContext *context) {
auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(), auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(),
@ -167,7 +167,7 @@ OperationInst *OperationInst::create(Location *location, OperationName name,
OperationInst *OperationInst::clone() const { OperationInst *OperationInst::clone() const {
SmallVector<CFGValue *, 8> operands; SmallVector<CFGValue *, 8> operands;
SmallVector<Type *, 8> resultTypes; SmallVector<Type, 8> resultTypes;
// Put together the operands and results. // Put together the operands and results.
for (auto *operand : getOperands()) for (auto *operand : getOperands())

View File

@ -21,6 +21,7 @@
#include "AttributeDetail.h" #include "AttributeDetail.h"
#include "AttributeListStorage.h" #include "AttributeListStorage.h"
#include "IntegerSetDetail.h" #include "IntegerSetDetail.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h" #include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
@ -44,11 +45,11 @@ using namespace mlir::detail;
using namespace llvm; using namespace llvm;
namespace { namespace {
struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> { struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> {
// Functions are uniqued based on their inputs and results. // Functions are uniqued based on their inputs and results.
using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>; using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>;
using DenseMapInfo<FunctionType *>::getHashValue; using DenseMapInfo<FunctionTypeStorage *>::getHashValue;
using DenseMapInfo<FunctionType *>::isEqual; using DenseMapInfo<FunctionTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { static unsigned getHashValue(KeyTy key) {
return hash_combine( return hash_combine(
@ -56,7 +57,7 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
hash_combine_range(key.second.begin(), key.second.end())); hash_combine_range(key.second.begin(), key.second.end()));
} }
static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) { static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey()) if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false; return false;
return lhs == KeyTy(rhs->getInputs(), rhs->getResults()); return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
@ -109,65 +110,64 @@ struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
} }
}; };
struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> { struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> {
// Vectors are uniqued based on their element type and shape. // Vectors are uniqued based on their element type and shape.
using KeyTy = std::pair<Type *, ArrayRef<int>>; using KeyTy = std::pair<Type, ArrayRef<int>>;
using DenseMapInfo<VectorType *>::getHashValue; using DenseMapInfo<VectorTypeStorage *>::getHashValue;
using DenseMapInfo<VectorType *>::isEqual; using DenseMapInfo<VectorTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { static unsigned getHashValue(KeyTy key) {
return hash_combine( return hash_combine(
DenseMapInfo<Type *>::getHashValue(key.first), DenseMapInfo<Type>::getHashValue(key.first),
hash_combine_range(key.second.begin(), key.second.end())); hash_combine_range(key.second.begin(), key.second.end()));
} }
static bool isEqual(const KeyTy &lhs, const VectorType *rhs) { static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey()) if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false; return false;
return lhs == KeyTy(rhs->getElementType(), rhs->getShape()); return lhs == KeyTy(rhs->elementType, rhs->getShape());
} }
}; };
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> { struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> {
// Ranked tensors are uniqued based on their element type and shape. // Ranked tensors are uniqued based on their element type and shape.
using KeyTy = std::pair<Type *, ArrayRef<int>>; using KeyTy = std::pair<Type, ArrayRef<int>>;
using DenseMapInfo<RankedTensorType *>::getHashValue; using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue;
using DenseMapInfo<RankedTensorType *>::isEqual; using DenseMapInfo<RankedTensorTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { static unsigned getHashValue(KeyTy key) {
return hash_combine( return hash_combine(
DenseMapInfo<Type *>::getHashValue(key.first), DenseMapInfo<Type>::getHashValue(key.first),
hash_combine_range(key.second.begin(), key.second.end())); hash_combine_range(key.second.begin(), key.second.end()));
} }
static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) { static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey()) if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false; return false;
return lhs == KeyTy(rhs->getElementType(), rhs->getShape()); return lhs == KeyTy(rhs->elementType, rhs->getShape());
} }
}; };
struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> { struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> {
// MemRefs are uniqued based on their element type, shape, affine map // MemRefs are uniqued based on their element type, shape, affine map
// composition, and memory space. // composition, and memory space.
using KeyTy = using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>; using DenseMapInfo<MemRefTypeStorage *>::getHashValue;
using DenseMapInfo<MemRefType *>::getHashValue; using DenseMapInfo<MemRefTypeStorage *>::isEqual;
using DenseMapInfo<MemRefType *>::isEqual;
static unsigned getHashValue(KeyTy key) { static unsigned getHashValue(KeyTy key) {
return hash_combine( return hash_combine(
DenseMapInfo<Type *>::getHashValue(std::get<0>(key)), DenseMapInfo<Type>::getHashValue(std::get<0>(key)),
hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()), hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()),
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()), hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
std::get<3>(key)); std::get<3>(key));
} }
static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) { static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey()) if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false; return false;
return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(), return lhs == std::make_tuple(rhs->elementType, rhs->getShape(),
rhs->getAffineMaps(), rhs->getMemorySpace()); rhs->getAffineMaps(), rhs->memorySpace);
} }
}; };
@ -221,7 +221,7 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
}; };
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> { struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>; using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue; using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual; using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
@ -239,7 +239,7 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
}; };
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> { struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
using KeyTy = std::pair<VectorOrTensorType *, StringRef>; using KeyTy = std::pair<VectorOrTensorType, StringRef>;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue; using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual; using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
@ -295,13 +295,14 @@ public:
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers; llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
// Uniquing table for 'other' types. // Uniquing table for 'other' types.
OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) - OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr}; int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {
nullptr};
// Uniquing table for 'float' types. // Uniquing table for 'float' types.
FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) - FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = { int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] =
nullptr}; {nullptr};
// Affine map uniquing. // Affine map uniquing.
using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>; using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
@ -324,26 +325,26 @@ public:
DenseMap<int64_t, AffineConstantExprStorage *> constExprs; DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
/// Integer type uniquing. /// Integer type uniquing.
DenseMap<unsigned, IntegerType *> integers; DenseMap<unsigned, IntegerTypeStorage *> integers;
/// Function type uniquing. /// Function type uniquing.
using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>; using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>;
FunctionTypeSet functions; FunctionTypeSet functions;
/// Vector type uniquing. /// Vector type uniquing.
using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>; using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>;
VectorTypeSet vectors; VectorTypeSet vectors;
/// Ranked tensor type uniquing. /// Ranked tensor type uniquing.
using RankedTensorTypeSet = using RankedTensorTypeSet =
DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>; DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>;
RankedTensorTypeSet rankedTensors; RankedTensorTypeSet rankedTensors;
/// Unranked tensor type uniquing. /// Unranked tensor type uniquing.
DenseMap<Type *, UnrankedTensorType *> unrankedTensors; DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors;
/// MemRef type uniquing. /// MemRef type uniquing.
using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>; using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>;
MemRefTypeSet memrefs; MemRefTypeSet memrefs;
// Attribute uniquing. // Attribute uniquing.
@ -355,13 +356,12 @@ public:
ArrayAttrSet arrayAttrs; ArrayAttrSet arrayAttrs;
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs; DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs; DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
DenseMap<Type *, TypeAttributeStorage *> typeAttrs; DenseMap<Type, TypeAttributeStorage *> typeAttrs;
using AttributeListSet = using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>; DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists; AttributeListSet attributeLists;
DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs; DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
DenseMap<std::pair<VectorOrTensorType *, Attribute>, DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *>
SplatElementsAttributeStorage *>
splatElementsAttrs; splatElementsAttrs;
using DenseElementsAttrSet = using DenseElementsAttrSet =
DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>; DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
@ -369,7 +369,7 @@ public:
using OpaqueElementsAttrSet = using OpaqueElementsAttrSet =
DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>; DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
OpaqueElementsAttrSet opaqueElementsAttrs; OpaqueElementsAttrSet opaqueElementsAttrs;
DenseMap<std::tuple<Type *, Attribute, Attribute>, DenseMap<std::tuple<Type, Attribute, Attribute>,
SparseElementsAttributeStorage *> SparseElementsAttributeStorage *>
sparseElementsAttrs; sparseElementsAttrs;
@ -556,19 +556,20 @@ FileLineColLoc *FileLineColLoc::get(UniquedFilename filename, unsigned line,
// Type uniquing // Type uniquing
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
IntegerType *IntegerType::get(unsigned width, MLIRContext *context) { IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
auto &impl = context->getImpl(); auto &impl = context->getImpl();
auto *&result = impl.integers[width]; auto *&result = impl.integers[width];
if (!result) { if (!result) {
result = impl.allocator.Allocate<IntegerType>(); result = impl.allocator.Allocate<IntegerTypeStorage>();
new (result) IntegerType(width, context); new (result) IntegerTypeStorage{{Kind::Integer, context}, width};
} }
return result; return result;
} }
FloatType *FloatType::get(Kind kind, MLIRContext *context) { FloatType FloatType::get(Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE && assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind"); kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
auto &impl = context->getImpl(); auto &impl = context->getImpl();
@ -580,16 +581,16 @@ FloatType *FloatType::get(Kind kind, MLIRContext *context) {
return entry; return entry;
// On the first use, we allocate them into the bump pointer. // On the first use, we allocate them into the bump pointer.
auto *ptr = impl.allocator.Allocate<FloatType>(); auto *ptr = impl.allocator.Allocate<FloatTypeStorage>();
// Initialize the memory using placement new. // Initialize the memory using placement new.
new (ptr) FloatType(kind, context); new (ptr) FloatTypeStorage{{kind, context}};
// Cache and return it. // Cache and return it.
return entry = ptr; return entry = ptr;
} }
OtherType *OtherType::get(Kind kind, MLIRContext *context) { OtherType OtherType::get(Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE && assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE &&
"Not an 'other' type kind"); "Not an 'other' type kind");
auto &impl = context->getImpl(); auto &impl = context->getImpl();
@ -600,18 +601,17 @@ OtherType *OtherType::get(Kind kind, MLIRContext *context) {
return entry; return entry;
// On the first use, we allocate them into the bump pointer. // On the first use, we allocate them into the bump pointer.
auto *ptr = impl.allocator.Allocate<OtherType>(); auto *ptr = impl.allocator.Allocate<OtherTypeStorage>();
// Initialize the memory using placement new. // Initialize the memory using placement new.
new (ptr) OtherType(kind, context); new (ptr) OtherTypeStorage{{kind, context}};
// Cache and return it. // Cache and return it.
return entry = ptr; return entry = ptr;
} }
FunctionType *FunctionType::get(ArrayRef<Type *> inputs, FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
ArrayRef<Type *> results, MLIRContext *context) {
MLIRContext *context) {
auto &impl = context->getImpl(); auto &impl = context->getImpl();
// Look to see if we already have this function type. // Look to see if we already have this function type.
@ -623,32 +623,34 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
return *existing.first; return *existing.first;
// On the first use, we allocate them into the bump pointer. // On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<FunctionType>(); auto *result = impl.allocator.Allocate<FunctionTypeStorage>();
// Copy the inputs and results into the bump pointer. // Copy the inputs and results into the bump pointer.
SmallVector<Type *, 16> types; SmallVector<Type, 16> types;
types.reserve(inputs.size() + results.size()); types.reserve(inputs.size() + results.size());
types.append(inputs.begin(), inputs.end()); types.append(inputs.begin(), inputs.end());
types.append(results.begin(), results.end()); types.append(results.begin(), results.end());
auto typesList = impl.copyInto(ArrayRef<Type *>(types)); auto typesList = impl.copyInto(ArrayRef<Type>(types));
// Initialize the memory using placement new. // Initialize the memory using placement new.
new (result) new (result) FunctionTypeStorage{
FunctionType(typesList.data(), inputs.size(), results.size(), context); {Kind::Function, context, static_cast<unsigned int>(inputs.size())},
static_cast<unsigned int>(results.size()),
typesList.data()};
// Cache and return it. // Cache and return it.
return *existing.first = result; return *existing.first = result;
} }
VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) { VectorType VectorType::get(ArrayRef<int> shape, Type elementType) {
assert(!shape.empty() && "vector types must have at least one dimension"); assert(!shape.empty() && "vector types must have at least one dimension");
assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) && assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) &&
"vectors elements must be primitives"); "vectors elements must be primitives");
assert(!std::any_of(shape.begin(), shape.end(), [](int i) { assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
return i < 0; return i < 0;
}) && "vector types must have static shape"); }) && "vector types must have static shape");
auto *context = elementType->getContext(); auto *context = elementType.getContext();
auto &impl = context->getImpl(); auto &impl = context->getImpl();
// Look to see if we already have this vector type. // Look to see if we already have this vector type.
@ -660,21 +662,23 @@ VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
return *existing.first; return *existing.first;
// On the first use, we allocate them into the bump pointer. // On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<VectorType>(); auto *result = impl.allocator.Allocate<VectorTypeStorage>();
// Copy the shape into the bump pointer. // Copy the shape into the bump pointer.
shape = impl.copyInto(shape); shape = impl.copyInto(shape);
// Initialize the memory using placement new. // Initialize the memory using placement new.
new (result) VectorType(shape, elementType, context); new (result) VectorTypeStorage{
{{Kind::Vector, context, static_cast<unsigned int>(shape.size())},
elementType},
shape.data()};
// Cache and return it. // Cache and return it.
return *existing.first = result; return *existing.first = result;
} }
RankedTensorType *RankedTensorType::get(ArrayRef<int> shape, RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) {
Type *elementType) { auto *context = elementType.getContext();
auto *context = elementType->getContext();
auto &impl = context->getImpl(); auto &impl = context->getImpl();
// Look to see if we already have this ranked tensor type. // Look to see if we already have this ranked tensor type.
@ -686,20 +690,23 @@ RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
return *existing.first; return *existing.first;
// On the first use, we allocate them into the bump pointer. // On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<RankedTensorType>(); auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>();
// Copy the shape into the bump pointer. // Copy the shape into the bump pointer.
shape = impl.copyInto(shape); shape = impl.copyInto(shape);
// Initialize the memory using placement new. // Initialize the memory using placement new.
new (result) RankedTensorType(shape, elementType, context); new (result) RankedTensorTypeStorage{
{{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())},
elementType}},
shape.data()};
// Cache and return it. // Cache and return it.
return *existing.first = result; return *existing.first = result;
} }
UnrankedTensorType *UnrankedTensorType::get(Type *elementType) { UnrankedTensorType UnrankedTensorType::get(Type elementType) {
auto *context = elementType->getContext(); auto *context = elementType.getContext();
auto &impl = context->getImpl(); auto &impl = context->getImpl();
// Look to see if we already have this unranked tensor type. // Look to see if we already have this unranked tensor type.
@ -710,17 +717,18 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
return result; return result;
// On the first use, we allocate them into the bump pointer. // On the first use, we allocate them into the bump pointer.
result = impl.allocator.Allocate<UnrankedTensorType>(); result = impl.allocator.Allocate<UnrankedTensorTypeStorage>();
// Initialize the memory using placement new. // Initialize the memory using placement new.
new (result) UnrankedTensorType(elementType, context); new (result) UnrankedTensorTypeStorage{
{{{Kind::UnrankedTensor, context}, elementType}}};
return result; return result;
} }
MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType, MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition, ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace) { unsigned memorySpace) {
auto *context = elementType->getContext(); auto *context = elementType.getContext();
auto &impl = context->getImpl(); auto &impl = context->getImpl();
// Drop the unbounded identity maps from the composition. // Drop the unbounded identity maps from the composition.
@ -744,7 +752,7 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
return *existing.first; return *existing.first;
// On the first use, we allocate them into the bump pointer. // On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<MemRefType>(); auto *result = impl.allocator.Allocate<MemRefTypeStorage>();
// Copy the shape into the bump pointer. // Copy the shape into the bump pointer.
shape = impl.copyInto(shape); shape = impl.copyInto(shape);
@ -755,8 +763,13 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
impl.copyInto(ArrayRef<AffineMap>(affineMapComposition)); impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
// Initialize the memory using placement new. // Initialize the memory using placement new.
new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace, new (result) MemRefTypeStorage{
context); {Kind::MemRef, context, static_cast<unsigned int>(shape.size())},
elementType,
shape.data(),
static_cast<unsigned int>(affineMapComposition.size()),
affineMapComposition.data(),
memorySpace};
// Cache and return it. // Cache and return it.
return *existing.first = result; return *existing.first = result;
} }
@ -895,7 +908,7 @@ IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
return result; return result;
} }
TypeAttr TypeAttr::get(Type *type, MLIRContext *context) { TypeAttr TypeAttr::get(Type type, MLIRContext *context) {
auto *&result = context->getImpl().typeAttrs[type]; auto *&result = context->getImpl().typeAttrs[type];
if (result) if (result)
return result; return result;
@ -1009,9 +1022,9 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
return *existing.first = result; return *existing.first = result;
} }
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type, SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
Attribute elt) { Attribute elt) {
auto &impl = type->getContext()->getImpl(); auto &impl = type.getContext()->getImpl();
// Look to see if we already have this. // Look to see if we already have this.
auto *&result = impl.splatElementsAttrs[{type, elt}]; auto *&result = impl.splatElementsAttrs[{type, elt}];
@ -1030,14 +1043,14 @@ SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
return result; return result;
} }
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type, DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
ArrayRef<char> data) { ArrayRef<char> data) {
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements(); auto bitsRequired = (long)type.getBitWidth() * type.getNumElements();
(void)bitsRequired; (void)bitsRequired;
assert((bitsRequired <= data.size() * 8L) && assert((bitsRequired <= data.size() * 8L) &&
"Input data bit size should be larger than that type requires"); "Input data bit size should be larger than that type requires");
auto &impl = type->getContext()->getImpl(); auto &impl = type.getContext()->getImpl();
// Look to see if this constant is already defined. // Look to see if this constant is already defined.
DenseElementsAttrInfo::KeyTy key({type, data}); DenseElementsAttrInfo::KeyTy key({type, data});
@ -1048,8 +1061,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
return *existing.first; return *existing.first;
// Otherwise, allocate a new one, unique it and return it. // Otherwise, allocate a new one, unique it and return it.
auto *eltType = type->getElementType(); auto eltType = type.getElementType();
switch (eltType->getKind()) { switch (eltType.getKind()) {
case Type::Kind::BF16: case Type::Kind::BF16:
case Type::Kind::F16: case Type::Kind::F16:
case Type::Kind::F32: case Type::Kind::F32:
@ -1064,7 +1077,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
return *existing.first = result; return *existing.first = result;
} }
case Type::Kind::Integer: { case Type::Kind::Integer: {
auto width = ::cast<IntegerType>(eltType)->getWidth(); auto width = eltType.cast<IntegerType>().getWidth();
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>(); auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64); auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy); std::uninitialized_copy(data.begin(), data.end(), copy);
@ -1080,12 +1093,12 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
} }
} }
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type, OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
StringRef bytes) { StringRef bytes) {
assert(isValidTensorElementType(type->getElementType()) && assert(isValidTensorElementType(type.getElementType()) &&
"Input element type should be a valid tensor element type"); "Input element type should be a valid tensor element type");
auto &impl = type->getContext()->getImpl(); auto &impl = type.getContext()->getImpl();
// Look to see if this constant is already defined. // Look to see if this constant is already defined.
OpaqueElementsAttrInfo::KeyTy key({type, bytes}); OpaqueElementsAttrInfo::KeyTy key({type, bytes});
@ -1104,10 +1117,10 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
return *existing.first = result; return *existing.first = result;
} }
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type, SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
DenseIntElementsAttr indices, DenseIntElementsAttr indices,
DenseElementsAttr values) { DenseElementsAttr values) {
auto &impl = type->getContext()->getImpl(); auto &impl = type.getContext()->getImpl();
// Look to see if we already have this. // Look to see if we already have this.
auto key = std::make_tuple(type, indices, values); auto key = std::make_tuple(type, indices, values);

View File

@ -377,7 +377,7 @@ bool OpTrait::impl::verifyAtLeastNResults(const Operation *op,
} }
bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) { bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
auto *type = op->getResult(0)->getType(); auto type = op->getResult(0)->getType();
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) { for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
if (op->getResult(i)->getType() != type) if (op->getResult(i)->getType() != type)
return op->emitOpError( return op->emitOpError(
@ -393,19 +393,19 @@ bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
/// If this is a vector type, or a tensor type, return the scalar element type /// If this is a vector type, or a tensor type, return the scalar element type
/// that it is built around, otherwise return the type unmodified. /// that it is built around, otherwise return the type unmodified.
static Type *getTensorOrVectorElementType(Type *type) { static Type getTensorOrVectorElementType(Type type) {
if (auto *vec = dyn_cast<VectorType>(type)) if (auto vec = type.dyn_cast<VectorType>())
return vec->getElementType(); return vec.getElementType();
// Look through tensor<vector<...>> to find the underlying element type. // Look through tensor<vector<...>> to find the underlying element type.
if (auto *tensor = dyn_cast<TensorType>(type)) if (auto tensor = type.dyn_cast<TensorType>())
return getTensorOrVectorElementType(tensor->getElementType()); return getTensorOrVectorElementType(tensor.getElementType());
return type; return type;
} }
bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
for (auto *result : op->getResults()) { for (auto *result : op->getResults()) {
if (!isa<FloatType>(getTensorOrVectorElementType(result->getType()))) if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
return op->emitOpError("requires a floating point type"); return op->emitOpError("requires a floating point type");
} }
@ -414,7 +414,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) { bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
for (auto *result : op->getResults()) { for (auto *result : op->getResults()) {
if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType()))) if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>())
return op->emitOpError("requires an integer type"); return op->emitOpError("requires an integer type");
} }
return false; return false;
@ -436,7 +436,7 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result,
bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops; SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type; Type type;
return parser->parseOperandList(ops, 2) || return parser->parseOperandList(ops, 2) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) || parser->parseColonType(type) ||
@ -448,7 +448,7 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
*p << op->getName() << ' ' << *op->getOperand(0) << ", " *p << op->getName() << ' ' << *op->getOperand(0) << ", "
<< *op->getOperand(1); << *op->getOperand(1);
p->printOptionalAttrDict(op->getAttrs()); p->printOptionalAttrDict(op->getAttrs());
*p << " : " << *op->getResult(0)->getType(); *p << " : " << op->getResult(0)->getType();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -456,14 +456,14 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void impl::buildCastOp(Builder *builder, OperationState *result, void impl::buildCastOp(Builder *builder, OperationState *result,
SSAValue *source, Type *destType) { SSAValue *source, Type destType) {
result->addOperands(source); result->addOperands(source);
result->addTypes(destType); result->addTypes(destType);
} }
bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType srcInfo; OpAsmParser::OperandType srcInfo;
Type *srcType, *dstType; Type srcType, dstType;
return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) || return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
parser->resolveOperand(srcInfo, srcType, result->operands) || parser->resolveOperand(srcInfo, srcType, result->operands) ||
parser->parseKeywordType("to", dstType) || parser->parseKeywordType("to", dstType) ||
@ -472,5 +472,5 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
void impl::printCastOp(const Operation *op, OpAsmPrinter *p) { void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
*p << op->getName() << ' ' << *op->getOperand(0) << " : " *p << op->getName() << ' ' << *op->getOperand(0) << " : "
<< *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType(); << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
} }

View File

@ -239,7 +239,7 @@ void Statement::moveBefore(StmtBlock *block,
/// Create a new OperationStmt with the specific fields. /// Create a new OperationStmt with the specific fields.
OperationStmt *OperationStmt::create(Location *location, OperationName name, OperationStmt *OperationStmt::create(Location *location, OperationName name,
ArrayRef<MLValue *> operands, ArrayRef<MLValue *> operands,
ArrayRef<Type *> resultTypes, ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes, ArrayRef<NamedAttribute> attributes,
MLIRContext *context) { MLIRContext *context) {
auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(), auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
@ -288,9 +288,9 @@ MLIRContext *OperationStmt::getContext() const {
// If we have a result or operand type, that is a constant time way to get // If we have a result or operand type, that is a constant time way to get
// to the context. // to the context.
if (getNumResults()) if (getNumResults())
return getResult(0)->getType()->getContext(); return getResult(0)->getType().getContext();
if (getNumOperands()) if (getNumOperands())
return getOperand(0)->getType()->getContext(); return getOperand(0)->getType().getContext();
// In the very odd case where we have no operands or results, fall back to // In the very odd case where we have no operands or results, fall back to
// doing a find. // doing a find.
@ -474,7 +474,7 @@ MLIRContext *IfStmt::getContext() const {
if (operands.empty()) if (operands.empty())
return findFunction()->getContext(); return findFunction()->getContext();
return getOperand(0)->getType()->getContext(); return getOperand(0)->getType().getContext();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -501,7 +501,7 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
operands.push_back(remapOperand(opValue)); operands.push_back(remapOperand(opValue));
if (auto *opStmt = dyn_cast<OperationStmt>(this)) { if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
SmallVector<Type *, 8> resultTypes; SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults()); resultTypes.reserve(opStmt->getNumResults());
for (auto *result : opStmt->getResults()) for (auto *result : opStmt->getResults())
resultTypes.push_back(result->getType()); resultTypes.push_back(result->getType());

126
mlir/lib/IR/TypeDetail.h Normal file
View File

@ -0,0 +1,126 @@
//===- TypeDetail.h - MLIR Affine Expr storage details ----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This holds implementation details of Type.
//
//===----------------------------------------------------------------------===//
#ifndef TYPEDETAIL_H_
#define TYPEDETAIL_H_
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
namespace mlir {
class AffineMap;
class MLIRContext;
namespace detail {
/// Base storage class appearing in a Type.
struct alignas(8) TypeStorage {
TypeStorage(Type::Kind kind, MLIRContext *context)
: context(context), kind(kind), subclassData(0) {}
TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData)
: context(context), kind(kind), subclassData(subclassData) {}
unsigned getSubclassData() const { return subclassData; }
void setSubclassData(unsigned val) {
subclassData = val;
// Ensure we don't have any accidental truncation.
assert(getSubclassData() == val && "Subclass data too large for field");
}
/// This refers to the MLIRContext in which this type was uniqued.
MLIRContext *const context;
/// Classification of the subclass, used for type checking.
Type::Kind kind : 8;
/// Space for subclasses to store data.
unsigned subclassData : 24;
};
struct IntegerTypeStorage : public TypeStorage {
unsigned width;
};
struct FloatTypeStorage : public TypeStorage {};
struct OtherTypeStorage : public TypeStorage {};
struct FunctionTypeStorage : public TypeStorage {
ArrayRef<Type> getInputs() const {
return ArrayRef<Type>(inputsAndResults, subclassData);
}
ArrayRef<Type> getResults() const {
return ArrayRef<Type>(inputsAndResults + subclassData, numResults);
}
unsigned numResults;
Type const *inputsAndResults;
};
struct VectorOrTensorTypeStorage : public TypeStorage {
Type elementType;
};
struct VectorTypeStorage : public VectorOrTensorTypeStorage {
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
const int *shapeElements;
};
struct TensorTypeStorage : public VectorOrTensorTypeStorage {};
struct RankedTensorTypeStorage : public TensorTypeStorage {
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
const int *shapeElements;
};
struct UnrankedTensorTypeStorage : public TensorTypeStorage {};
struct MemRefTypeStorage : public TypeStorage {
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
ArrayRef<AffineMap> getAffineMaps() const {
return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
}
/// The type of each scalar element of the memref.
Type elementType;
/// An array of integers which stores the shape dimension sizes.
const int *shapeElements;
/// The number of affine maps in the 'affineMapList' array.
const unsigned numAffineMaps;
/// List of affine maps in the memref's layout/index map composition.
AffineMap const *affineMapList;
/// Memory space in which data referenced by memref resides.
const unsigned memorySpace;
};
} // namespace detail
} // namespace mlir
#endif // TYPEDETAIL_H_

View File

@ -16,10 +16,17 @@
// ============================================================================= // =============================================================================
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineMap.h" #include "mlir/IR/AffineMap.h"
#include "mlir/Support/STLExtras.h" #include "mlir/Support/STLExtras.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
using namespace mlir; using namespace mlir;
using namespace mlir::detail;
Type::Kind Type::getKind() const { return type->kind; }
MLIRContext *Type::getContext() const { return type->context; }
unsigned Type::getBitWidth() const { unsigned Type::getBitWidth() const {
switch (getKind()) { switch (getKind()) {
@ -32,34 +39,49 @@ unsigned Type::getBitWidth() const {
case Type::Kind::F64: case Type::Kind::F64:
return 64; return 64;
case Type::Kind::Integer: case Type::Kind::Integer:
return cast<IntegerType>(this)->getWidth(); return cast<IntegerType>().getWidth();
case Type::Kind::Vector: case Type::Kind::Vector:
case Type::Kind::RankedTensor: case Type::Kind::RankedTensor:
case Type::Kind::UnrankedTensor: case Type::Kind::UnrankedTensor:
return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth(); return cast<VectorOrTensorType>().getElementType().getBitWidth();
// TODO: Handle more types. // TODO: Handle more types.
default: default:
llvm_unreachable("unexpected type"); llvm_unreachable("unexpected type");
} }
} }
IntegerType::IntegerType(unsigned width, MLIRContext *context) unsigned Type::getSubclassData() const { return type->getSubclassData(); }
: Type(Kind::Integer, context), width(width) { void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
unsigned IntegerType::getWidth() const {
return static_cast<ImplType *>(type)->width;
} }
FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {} FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {} OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {}
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs, FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {}
unsigned numResults, MLIRContext *context)
: Type(Kind::Function, context, numInputs), numResults(numResults),
inputsAndResults(inputsAndResults) {}
VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context, ArrayRef<Type> FunctionType::getInputs() const {
Type *elementType, unsigned subClassData) return static_cast<ImplType *>(type)->getInputs();
: Type(kind, context, subClassData), elementType(elementType) {} }
unsigned FunctionType::getNumResults() const {
return static_cast<ImplType *>(type)->numResults;
}
ArrayRef<Type> FunctionType::getResults() const {
return static_cast<ImplType *>(type)->getResults();
}
VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {}
Type VectorOrTensorType::getElementType() const {
return static_cast<ImplType *>(type)->elementType;
}
unsigned VectorOrTensorType::getNumElements() const { unsigned VectorOrTensorType::getNumElements() const {
switch (getKind()) { switch (getKind()) {
@ -103,11 +125,11 @@ int VectorOrTensorType::getDimSize(unsigned i) const {
ArrayRef<int> VectorOrTensorType::getShape() const { ArrayRef<int> VectorOrTensorType::getShape() const {
switch (getKind()) { switch (getKind()) {
case Kind::Vector: case Kind::Vector:
return cast<VectorType>(this)->getShape(); return cast<VectorType>().getShape();
case Kind::RankedTensor: case Kind::RankedTensor:
return cast<RankedTensorType>(this)->getShape(); return cast<RankedTensorType>().getShape();
case Kind::UnrankedTensor: case Kind::UnrankedTensor:
return cast<RankedTensorType>(this)->getShape(); return cast<RankedTensorType>().getShape();
default: default:
llvm_unreachable("not a VectorOrTensorType"); llvm_unreachable("not a VectorOrTensorType");
} }
@ -118,35 +140,38 @@ bool VectorOrTensorType::hasStaticShape() const {
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; }); return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
} }
VectorType::VectorType(ArrayRef<int> shape, Type *elementType, VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
MLIRContext *context)
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
shapeElements(shape.data()) {}
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context) ArrayRef<int> VectorType::getShape() const {
: VectorOrTensorType(kind, context, elementType) { return static_cast<ImplType *>(type)->getShape();
assert(isValidTensorElementType(elementType));
} }
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType, TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
MLIRContext *context)
: TensorType(Kind::RankedTensor, elementType, context), RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
shapeElements(shape.data()) {
setSubclassData(shape.size()); ArrayRef<int> RankedTensorType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
} }
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context) UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
: TensorType(Kind::UnrankedTensor, elementType, context) {}
MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType, MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {}
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
MLIRContext *context) ArrayRef<int> MemRefType::getShape() const {
: Type(Kind::MemRef, context, shape.size()), elementType(elementType), return static_cast<ImplType *>(type)->getShape();
shapeElements(shape.data()), numAffineMaps(affineMapList.size()), }
affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
Type MemRefType::getElementType() const {
return static_cast<ImplType *>(type)->elementType;
}
ArrayRef<AffineMap> MemRefType::getAffineMaps() const { ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
return ArrayRef<AffineMap>(affineMapList, numAffineMaps); return static_cast<ImplType *>(type)->getAffineMaps();
}
unsigned MemRefType::getMemorySpace() const {
return static_cast<ImplType *>(type)->memorySpace;
} }
unsigned MemRefType::getNumDynamicDims() const { unsigned MemRefType::getNumDynamicDims() const {

View File

@ -182,19 +182,19 @@ public:
// as the results of their action. // as the results of their action.
// Type parsing. // Type parsing.
VectorType *parseVectorType(); VectorType parseVectorType();
ParseResult parseXInDimensionList(); ParseResult parseXInDimensionList();
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions); ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
Type *parseTensorType(); Type parseTensorType();
Type *parseMemRefType(); Type parseMemRefType();
Type *parseFunctionType(); Type parseFunctionType();
Type *parseType(); Type parseType();
ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements); ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
ParseResult parseTypeList(SmallVectorImpl<Type *> &elements); ParseResult parseTypeList(SmallVectorImpl<Type> &elements);
// Attribute parsing. // Attribute parsing.
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType *type); FunctionType type);
Attribute parseAttribute(); Attribute parseAttribute();
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
@ -206,9 +206,9 @@ public:
AffineMap parseAffineMapReference(); AffineMap parseAffineMapReference();
IntegerSet parseIntegerSetInline(); IntegerSet parseIntegerSetInline();
IntegerSet parseIntegerSetReference(); IntegerSet parseIntegerSetReference();
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type); DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type);
DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector); DenseElementsAttr parseDenseElementsAttr(Type eltType, bool isVector);
VectorOrTensorType *parseVectorOrTensorType(); VectorOrTensorType parseVectorOrTensorType();
private: private:
// The Parser is subclassed and reinstantiated. Do not add additional // The Parser is subclassed and reinstantiated. Do not add additional
@ -299,7 +299,7 @@ ParseResult Parser::parseCommaSeparatedListUntil(
/// float-type ::= `f16` | `bf16` | `f32` | `f64` /// float-type ::= `f16` | `bf16` | `f32` | `f64`
/// other-type ::= `index` | `tf_control` /// other-type ::= `index` | `tf_control`
/// ///
Type *Parser::parseType() { Type Parser::parseType() {
switch (getToken().getKind()) { switch (getToken().getKind()) {
default: default:
return (emitError("expected type"), nullptr); return (emitError("expected type"), nullptr);
@ -368,7 +368,7 @@ Type *Parser::parseType() {
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>` /// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
/// const-dimension-list ::= (integer-literal `x`)+ /// const-dimension-list ::= (integer-literal `x`)+
/// ///
VectorType *Parser::parseVectorType() { VectorType Parser::parseVectorType() {
consumeToken(Token::kw_vector); consumeToken(Token::kw_vector);
if (parseToken(Token::less, "expected '<' in vector type")) if (parseToken(Token::less, "expected '<' in vector type"))
@ -402,11 +402,11 @@ VectorType *Parser::parseVectorType() {
// Parse the element type. // Parse the element type.
auto typeLoc = getToken().getLoc(); auto typeLoc = getToken().getLoc();
auto *elementType = parseType(); auto elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr; return nullptr;
if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType)) if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
return (emitError(typeLoc, "invalid vector element type"), nullptr); return (emitError(typeLoc, "invalid vector element type"), nullptr);
return VectorType::get(dimensions, elementType); return VectorType::get(dimensions, elementType);
@ -461,7 +461,7 @@ ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
/// tensor-type ::= `tensor` `<` dimension-list element-type `>` /// tensor-type ::= `tensor` `<` dimension-list element-type `>`
/// dimension-list ::= dimension-list-ranked | `*x` /// dimension-list ::= dimension-list-ranked | `*x`
/// ///
Type *Parser::parseTensorType() { Type Parser::parseTensorType() {
consumeToken(Token::kw_tensor); consumeToken(Token::kw_tensor);
if (parseToken(Token::less, "expected '<' in tensor type")) if (parseToken(Token::less, "expected '<' in tensor type"))
@ -485,7 +485,7 @@ Type *Parser::parseTensorType() {
// Parse the element type. // Parse the element type.
auto typeLoc = getToken().getLoc(); auto typeLoc = getToken().getLoc();
auto *elementType = parseType(); auto elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
return nullptr; return nullptr;
@ -505,7 +505,7 @@ Type *Parser::parseTensorType() {
/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
/// memory-space ::= integer-literal /* | TODO: address-space-id */ /// memory-space ::= integer-literal /* | TODO: address-space-id */
/// ///
Type *Parser::parseMemRefType() { Type Parser::parseMemRefType() {
consumeToken(Token::kw_memref); consumeToken(Token::kw_memref);
if (parseToken(Token::less, "expected '<' in memref type")) if (parseToken(Token::less, "expected '<' in memref type"))
@ -517,12 +517,12 @@ Type *Parser::parseMemRefType() {
// Parse the element type. // Parse the element type.
auto typeLoc = getToken().getLoc(); auto typeLoc = getToken().getLoc();
auto *elementType = parseType(); auto elementType = parseType();
if (!elementType) if (!elementType)
return nullptr; return nullptr;
if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) && if (!elementType.isa<IntegerType>() && !elementType.isa<FloatType>() &&
!isa<VectorType>(elementType)) !elementType.isa<VectorType>())
return (emitError(typeLoc, "invalid memref element type"), nullptr); return (emitError(typeLoc, "invalid memref element type"), nullptr);
// Parse semi-affine-map-composition. // Parse semi-affine-map-composition.
@ -581,10 +581,10 @@ Type *Parser::parseMemRefType() {
/// ///
/// function-type ::= type-list-parens `->` type-list /// function-type ::= type-list-parens `->` type-list
/// ///
Type *Parser::parseFunctionType() { Type Parser::parseFunctionType() {
assert(getToken().is(Token::l_paren)); assert(getToken().is(Token::l_paren));
SmallVector<Type *, 4> arguments, results; SmallVector<Type, 4> arguments, results;
if (parseTypeList(arguments) || if (parseTypeList(arguments) ||
parseToken(Token::arrow, "expected '->' in function type") || parseToken(Token::arrow, "expected '->' in function type") ||
parseTypeList(results)) parseTypeList(results))
@ -598,7 +598,7 @@ Type *Parser::parseFunctionType() {
/// ///
/// type-list-no-parens ::= type (`,` type)* /// type-list-no-parens ::= type (`,` type)*
/// ///
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) { ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
auto parseElt = [&]() -> ParseResult { auto parseElt = [&]() -> ParseResult {
auto elt = parseType(); auto elt = parseType();
elements.push_back(elt); elements.push_back(elt);
@ -615,7 +615,7 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
/// type-list-parens ::= `(` `)` /// type-list-parens ::= `(` `)`
/// | `(` type-list-no-parens `)` /// | `(` type-list-no-parens `)`
/// ///
ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) { ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
auto parseElt = [&]() -> ParseResult { auto parseElt = [&]() -> ParseResult {
auto elt = parseType(); auto elt = parseType();
elements.push_back(elt); elements.push_back(elt);
@ -639,8 +639,8 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
namespace { namespace {
class TensorLiteralParser { class TensorLiteralParser {
public: public:
TensorLiteralParser(Parser &p, Type *eltTy) TensorLiteralParser(Parser &p, Type eltTy)
: p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {} : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {}
ParseResult parse() { return parseList(shape); } ParseResult parse() { return parseList(shape); }
@ -676,7 +676,7 @@ private:
} }
Parser &p; Parser &p;
Type *eltTy; Type eltTy;
size_t currBitPos; size_t currBitPos;
size_t bitsWidth; size_t bitsWidth;
SmallVector<int, 4> shape; SmallVector<int, 4> shape;
@ -698,7 +698,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
if (!result) if (!result)
return p.emitError("expected tensor element"); return p.emitError("expected tensor element");
// check result matches the element type. // check result matches the element type.
switch (eltTy->getKind()) { switch (eltTy.getKind()) {
case Type::Kind::BF16: case Type::Kind::BF16:
case Type::Kind::F16: case Type::Kind::F16:
case Type::Kind::F32: case Type::Kind::F32:
@ -779,7 +779,7 @@ ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl<int> &dims) {
/// synthesizing a forward reference) or emit an error and return null on /// synthesizing a forward reference) or emit an error and return null on
/// failure. /// failure.
Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType *type) { FunctionType type) {
Identifier name = builder.getIdentifier(nameStr.drop_front()); Identifier name = builder.getIdentifier(nameStr.drop_front());
// See if the function has already been defined in the module. // See if the function has already been defined in the module.
@ -902,10 +902,10 @@ Attribute Parser::parseAttribute() {
if (parseToken(Token::colon, "expected ':' and function type")) if (parseToken(Token::colon, "expected ':' and function type"))
return nullptr; return nullptr;
auto typeLoc = getToken().getLoc(); auto typeLoc = getToken().getLoc();
Type *type = parseType(); Type type = parseType();
if (!type) if (!type)
return nullptr; return nullptr;
auto *fnType = dyn_cast<FunctionType>(type); auto fnType = type.dyn_cast<FunctionType>();
if (!fnType) if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr); return (emitError(typeLoc, "expected function type"), nullptr);
@ -916,7 +916,7 @@ Attribute Parser::parseAttribute() {
consumeToken(Token::kw_opaque); consumeToken(Token::kw_opaque);
if (parseToken(Token::less, "expected '<' after 'opaque'")) if (parseToken(Token::less, "expected '<' after 'opaque'"))
return nullptr; return nullptr;
auto *type = parseVectorOrTensorType(); auto type = parseVectorOrTensorType();
if (!type) if (!type)
return nullptr; return nullptr;
auto val = getToken().getStringValue(); auto val = getToken().getStringValue();
@ -937,7 +937,7 @@ Attribute Parser::parseAttribute() {
if (parseToken(Token::less, "expected '<' after 'splat'")) if (parseToken(Token::less, "expected '<' after 'splat'"))
return nullptr; return nullptr;
auto *type = parseVectorOrTensorType(); auto type = parseVectorOrTensorType();
if (!type) if (!type)
return nullptr; return nullptr;
switch (getToken().getKind()) { switch (getToken().getKind()) {
@ -959,7 +959,7 @@ Attribute Parser::parseAttribute() {
if (parseToken(Token::less, "expected '<' after 'dense'")) if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr; return nullptr;
auto *type = parseVectorOrTensorType(); auto type = parseVectorOrTensorType();
if (!type) if (!type)
return nullptr; return nullptr;
@ -981,41 +981,41 @@ Attribute Parser::parseAttribute() {
if (parseToken(Token::less, "Expected '<' after 'sparse'")) if (parseToken(Token::less, "Expected '<' after 'sparse'"))
return nullptr; return nullptr;
auto *type = parseVectorOrTensorType(); auto type = parseVectorOrTensorType();
if (!type) if (!type)
return nullptr; return nullptr;
switch (getToken().getKind()) { switch (getToken().getKind()) {
case Token::l_square: { case Token::l_square: {
/// Parse indices /// Parse indices
auto *indicesEltType = builder.getIntegerType(32); auto indicesEltType = builder.getIntegerType(32);
auto indices = auto indices =
parseDenseElementsAttr(indicesEltType, isa<VectorType>(type)); parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
if (parseToken(Token::comma, "expected ','")) if (parseToken(Token::comma, "expected ','"))
return nullptr; return nullptr;
/// Parse values. /// Parse values.
auto *valuesEltType = type->getElementType(); auto valuesEltType = type.getElementType();
auto values = auto values =
parseDenseElementsAttr(valuesEltType, isa<VectorType>(type)); parseDenseElementsAttr(valuesEltType, type.isa<VectorType>());
/// Sanity check. /// Sanity check.
auto *indicesType = indices.getType(); auto indicesType = indices.getType();
auto *valuesType = values.getType(); auto valuesType = values.getType();
auto sameShape = (indicesType->getRank() == 1) || auto sameShape = (indicesType.getRank() == 1) ||
(type->getRank() == indicesType->getDimSize(1)); (type.getRank() == indicesType.getDimSize(1));
auto sameElementNum = auto sameElementNum =
indicesType->getDimSize(0) == valuesType->getDimSize(0); indicesType.getDimSize(0) == valuesType.getDimSize(0);
if (!sameShape || !sameElementNum) { if (!sameShape || !sameElementNum) {
std::string str; std::string str;
llvm::raw_string_ostream s(str); llvm::raw_string_ostream s(str);
s << "expected shape (["; s << "expected shape ([";
interleaveComma(type->getShape(), s); interleaveComma(type.getShape(), s);
s << "]); inferred shape of indices literal (["; s << "]); inferred shape of indices literal ([";
interleaveComma(indicesType->getShape(), s); interleaveComma(indicesType.getShape(), s);
s << "]); inferred shape of values literal (["; s << "]); inferred shape of values literal ([";
interleaveComma(valuesType->getShape(), s); interleaveComma(valuesType.getShape(), s);
s << "])"; s << "])";
return (emitError(s.str()), nullptr); return (emitError(s.str()), nullptr);
} }
@ -1035,7 +1035,7 @@ Attribute Parser::parseAttribute() {
nullptr); nullptr);
} }
default: { default: {
if (Type *type = parseType()) if (Type type = parseType())
return builder.getTypeAttr(type); return builder.getTypeAttr(type);
return nullptr; return nullptr;
} }
@ -1051,12 +1051,12 @@ Attribute Parser::parseAttribute() {
/// ///
/// This method returns a constructed dense elements attribute with the shape /// This method returns a constructed dense elements attribute with the shape
/// from the parsing result. /// from the parsing result.
DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) { DenseElementsAttr Parser::parseDenseElementsAttr(Type eltType, bool isVector) {
TensorLiteralParser literalParser(*this, eltType); TensorLiteralParser literalParser(*this, eltType);
if (literalParser.parse()) if (literalParser.parse())
return nullptr; return nullptr;
VectorOrTensorType *type; VectorOrTensorType type;
if (isVector) { if (isVector) {
type = builder.getVectorType(literalParser.getShape(), eltType); type = builder.getVectorType(literalParser.getShape(), eltType);
} else { } else {
@ -1076,18 +1076,18 @@ DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
/// This method compares the shapes from the parsing result and that from the /// This method compares the shapes from the parsing result and that from the
/// input argument. It returns a constructed dense elements attribute if both /// input argument. It returns a constructed dense elements attribute if both
/// match. /// match.
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) { DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
auto *eltTy = type->getElementType(); auto eltTy = type.getElementType();
TensorLiteralParser literalParser(*this, eltTy); TensorLiteralParser literalParser(*this, eltTy);
if (literalParser.parse()) if (literalParser.parse())
return nullptr; return nullptr;
if (literalParser.getShape() != type->getShape()) { if (literalParser.getShape() != type.getShape()) {
std::string str; std::string str;
llvm::raw_string_ostream s(str); llvm::raw_string_ostream s(str);
s << "inferred shape of elements literal (["; s << "inferred shape of elements literal ([";
interleaveComma(literalParser.getShape(), s); interleaveComma(literalParser.getShape(), s);
s << "]) does not match type (["; s << "]) does not match type ([";
interleaveComma(type->getShape(), s); interleaveComma(type.getShape(), s);
s << "])"; s << "])";
return (emitError(s.str()), nullptr); return (emitError(s.str()), nullptr);
} }
@ -1100,8 +1100,8 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
/// vector-or-tensor-type ::= vector-type | tensor-type /// vector-or-tensor-type ::= vector-type | tensor-type
/// ///
/// This method also checks the type has static shape and ranked. /// This method also checks the type has static shape and ranked.
VectorOrTensorType *Parser::parseVectorOrTensorType() { VectorOrTensorType Parser::parseVectorOrTensorType() {
auto *type = dyn_cast<VectorOrTensorType>(parseType()); auto type = parseType().dyn_cast<VectorOrTensorType>();
if (!type) { if (!type) {
return (emitError("expected elements literal has a tensor or vector type"), return (emitError("expected elements literal has a tensor or vector type"),
nullptr); nullptr);
@ -1110,7 +1110,7 @@ VectorOrTensorType *Parser::parseVectorOrTensorType() {
if (parseToken(Token::comma, "expected ','")) if (parseToken(Token::comma, "expected ','"))
return nullptr; return nullptr;
if (!type->hasStaticShape() || type->getRank() == -1) { if (!type.hasStaticShape() || type.getRank() == -1) {
return (emitError("tensor literals must be ranked and have static shape"), return (emitError("tensor literals must be ranked and have static shape"),
nullptr); nullptr);
} }
@ -1834,7 +1834,7 @@ public:
/// Given a reference to an SSA value and its type, return a reference. This /// Given a reference to an SSA value and its type, return a reference. This
/// returns null on failure. /// returns null on failure.
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type); SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type);
/// Register a definition of a value with the symbol table. /// Register a definition of a value with the symbol table.
ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value); ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
@ -1845,11 +1845,11 @@ public:
template <typename ResultType> template <typename ResultType>
ResultType parseSSADefOrUseAndType( ResultType parseSSADefOrUseAndType(
const std::function<ResultType(SSAUseInfo, Type *)> &action); const std::function<ResultType(SSAUseInfo, Type)> &action);
SSAValue *parseSSAUseAndType() { SSAValue *parseSSAUseAndType() {
return parseSSADefOrUseAndType<SSAValue *>( return parseSSADefOrUseAndType<SSAValue *>(
[&](SSAUseInfo useInfo, Type *type) -> SSAValue * { [&](SSAUseInfo useInfo, Type type) -> SSAValue * {
return resolveSSAUse(useInfo, type); return resolveSSAUse(useInfo, type);
}); });
} }
@ -1880,7 +1880,7 @@ private:
/// their first reference, to allow checking for use of undefined values. /// their first reference, to allow checking for use of undefined values.
DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders; DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type *type); SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type);
/// Return true if this is a forward reference. /// Return true if this is a forward reference.
bool isForwardReferencePlaceholder(SSAValue *value) { bool isForwardReferencePlaceholder(SSAValue *value) {
@ -1891,7 +1891,7 @@ private:
/// Create and remember a new placeholder for a forward reference. /// Create and remember a new placeholder for a forward reference.
SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
Type *type) { Type type) {
// Forward references are always created as instructions, even in ML // Forward references are always created as instructions, even in ML
// functions, because we just need something with a def/use chain. // functions, because we just need something with a def/use chain.
// //
@ -1908,7 +1908,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
/// Given an unbound reference to an SSA value and its type, return the value /// Given an unbound reference to an SSA value and its type, return the value
/// it specifies. This returns null on failure. /// it specifies. This returns null on failure.
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) { SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
auto &entries = values[useInfo.name]; auto &entries = values[useInfo.name];
// If we have already seen a value of this name, return it. // If we have already seen a value of this name, return it.
@ -2057,14 +2057,14 @@ FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
/// ssa-use-and-type ::= ssa-use `:` type /// ssa-use-and-type ::= ssa-use `:` type
template <typename ResultType> template <typename ResultType>
ResultType FunctionParser::parseSSADefOrUseAndType( ResultType FunctionParser::parseSSADefOrUseAndType(
const std::function<ResultType(SSAUseInfo, Type *)> &action) { const std::function<ResultType(SSAUseInfo, Type)> &action) {
SSAUseInfo useInfo; SSAUseInfo useInfo;
if (parseSSAUse(useInfo) || if (parseSSAUse(useInfo) ||
parseToken(Token::colon, "expected ':' and type for SSA operand")) parseToken(Token::colon, "expected ':' and type for SSA operand"))
return nullptr; return nullptr;
auto *type = parseType(); auto type = parseType();
if (!type) if (!type)
return nullptr; return nullptr;
@ -2101,7 +2101,7 @@ ParseResult FunctionParser::parseOptionalSSAUseAndTypeList(
if (valueIDs.empty()) if (valueIDs.empty())
return ParseSuccess; return ParseSuccess;
SmallVector<Type *, 4> types; SmallVector<Type, 4> types;
if (parseToken(Token::colon, "expected ':' in operand list") || if (parseToken(Token::colon, "expected ':' in operand list") ||
parseTypeListNoParens(types)) parseTypeListNoParens(types))
return ParseFailure; return ParseFailure;
@ -2209,14 +2209,14 @@ Operation *FunctionParser::parseVerboseOperation(
auto type = parseType(); auto type = parseType();
if (!type) if (!type)
return nullptr; return nullptr;
auto fnType = dyn_cast<FunctionType>(type); auto fnType = type.dyn_cast<FunctionType>();
if (!fnType) if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr); return (emitError(typeLoc, "expected function type"), nullptr);
result.addTypes(fnType->getResults()); result.addTypes(fnType.getResults());
// Check that we have the right number of types for the operands. // Check that we have the right number of types for the operands.
auto operandTypes = fnType->getInputs(); auto operandTypes = fnType.getInputs();
if (operandTypes.size() != operandInfos.size()) { if (operandTypes.size() != operandInfos.size()) {
auto plural = "s"[operandInfos.size() == 1]; auto plural = "s"[operandInfos.size() == 1];
return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) + return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
@ -2253,17 +2253,17 @@ public:
return parser.parseToken(Token::comma, "expected ','"); return parser.parseToken(Token::comma, "expected ','");
} }
bool parseColonType(Type *&result) override { bool parseColonType(Type &result) override {
return parser.parseToken(Token::colon, "expected ':'") || return parser.parseToken(Token::colon, "expected ':'") ||
!(result = parser.parseType()); !(result = parser.parseType());
} }
bool parseColonTypeList(SmallVectorImpl<Type *> &result) override { bool parseColonTypeList(SmallVectorImpl<Type> &result) override {
if (parser.parseToken(Token::colon, "expected ':'")) if (parser.parseToken(Token::colon, "expected ':'"))
return true; return true;
do { do {
if (auto *type = parser.parseType()) if (auto type = parser.parseType())
result.push_back(type); result.push_back(type);
else else
return true; return true;
@ -2273,7 +2273,7 @@ public:
} }
/// Parse a keyword followed by a type. /// Parse a keyword followed by a type.
bool parseKeywordType(const char *keyword, Type *&result) override { bool parseKeywordType(const char *keyword, Type &result) override {
if (parser.getTokenSpelling() != keyword) if (parser.getTokenSpelling() != keyword)
return parser.emitError("expected '" + Twine(keyword) + "'"); return parser.emitError("expected '" + Twine(keyword) + "'");
parser.consumeToken(); parser.consumeToken();
@ -2396,7 +2396,7 @@ public:
} }
/// Resolve a parse function name and a type into a function reference. /// Resolve a parse function name and a type into a function reference.
virtual bool resolveFunctionName(StringRef name, FunctionType *type, virtual bool resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) { llvm::SMLoc loc, Function *&result) {
result = parser.resolveFunctionReference(name, loc, type); result = parser.resolveFunctionReference(name, loc, type);
return result == nullptr; return result == nullptr;
@ -2410,7 +2410,7 @@ public:
llvm::SMLoc getNameLoc() const override { return nameLoc; } llvm::SMLoc getNameLoc() const override { return nameLoc; }
bool resolveOperand(const OperandType &operand, Type *type, bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<SSAValue *> &result) override { SmallVectorImpl<SSAValue *> &result) override {
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location}; operand.location};
@ -2559,11 +2559,11 @@ ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList(
return ParseSuccess; return ParseSuccess;
return parseCommaSeparatedList([&]() -> ParseResult { return parseCommaSeparatedList([&]() -> ParseResult {
auto type = parseSSADefOrUseAndType<Type *>( auto type = parseSSADefOrUseAndType<Type>(
[&](SSAUseInfo useInfo, Type *type) -> Type * { [&](SSAUseInfo useInfo, Type type) -> Type {
BBArgument *arg = owner->addArgument(type); BBArgument *arg = owner->addArgument(type);
if (addDefinition(useInfo, arg)) if (addDefinition(useInfo, arg))
return nullptr; return {};
return type; return type;
}); });
return type ? ParseSuccess : ParseFailure; return type ? ParseSuccess : ParseFailure;
@ -2908,7 +2908,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
" symbol count must match"); " symbol count must match");
// Resolve SSA uses. // Resolve SSA uses.
Type *indexType = builder.getIndexType(); Type indexType = builder.getIndexType();
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) { for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
SSAValue *sval = resolveSSAUse(opInfo[i], indexType); SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
if (!sval) if (!sval)
@ -3187,9 +3187,9 @@ private:
ParseResult parseAffineStructureDef(); ParseResult parseAffineStructureDef();
// Functions. // Functions.
ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes, ParseResult parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<StringRef> &argNames); SmallVectorImpl<StringRef> &argNames);
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type, ParseResult parseFunctionSignature(StringRef &name, FunctionType &type,
SmallVectorImpl<StringRef> *argNames); SmallVectorImpl<StringRef> *argNames);
ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs); ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs);
ParseResult parseExtFunc(); ParseResult parseExtFunc();
@ -3248,7 +3248,7 @@ ParseResult ModuleParser::parseAffineStructureDef() {
/// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/ /// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
/// ///
ParseResult ParseResult
ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes, ModuleParser::parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<StringRef> &argNames) { SmallVectorImpl<StringRef> &argNames) {
consumeToken(Token::l_paren); consumeToken(Token::l_paren);
@ -3284,7 +3284,7 @@ ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
/// type-list)? /// type-list)?
/// ///
ParseResult ParseResult
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type, ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
SmallVectorImpl<StringRef> *argNames) { SmallVectorImpl<StringRef> *argNames) {
if (getToken().isNot(Token::at_identifier)) if (getToken().isNot(Token::at_identifier))
return emitError("expected a function identifier like '@foo'"); return emitError("expected a function identifier like '@foo'");
@ -3295,7 +3295,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
if (getToken().isNot(Token::l_paren)) if (getToken().isNot(Token::l_paren))
return emitError("expected '(' in function signature"); return emitError("expected '(' in function signature");
SmallVector<Type *, 4> argTypes; SmallVector<Type, 4> argTypes;
ParseResult parseResult; ParseResult parseResult;
if (argNames) if (argNames)
@ -3307,7 +3307,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
return ParseFailure; return ParseFailure;
// Parse the return type if present. // Parse the return type if present.
SmallVector<Type *, 4> results; SmallVector<Type, 4> results;
if (consumeIf(Token::arrow)) { if (consumeIf(Token::arrow)) {
if (parseTypeList(results)) if (parseTypeList(results))
return ParseFailure; return ParseFailure;
@ -3340,7 +3340,7 @@ ParseResult ModuleParser::parseExtFunc() {
auto loc = getToken().getLoc(); auto loc = getToken().getLoc();
StringRef name; StringRef name;
FunctionType *type = nullptr; FunctionType type;
if (parseFunctionSignature(name, type, /*arguments*/ nullptr)) if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
return ParseFailure; return ParseFailure;
@ -3372,7 +3372,7 @@ ParseResult ModuleParser::parseCFGFunc() {
auto loc = getToken().getLoc(); auto loc = getToken().getLoc();
StringRef name; StringRef name;
FunctionType *type = nullptr; FunctionType type;
if (parseFunctionSignature(name, type, /*arguments*/ nullptr)) if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
return ParseFailure; return ParseFailure;
@ -3405,7 +3405,7 @@ ParseResult ModuleParser::parseMLFunc() {
consumeToken(Token::kw_mlfunc); consumeToken(Token::kw_mlfunc);
StringRef name; StringRef name;
FunctionType *type = nullptr; FunctionType type;
SmallVector<StringRef, 4> argNames; SmallVector<StringRef, 4> argNames;
auto loc = getToken().getLoc(); auto loc = getToken().getLoc();

View File

@ -138,23 +138,23 @@ void AddIOp::getCanonicalizationPatterns(OwningPatternList &results,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void AllocOp::build(Builder *builder, OperationState *result, void AllocOp::build(Builder *builder, OperationState *result,
MemRefType *memrefType, ArrayRef<SSAValue *> operands) { MemRefType memrefType, ArrayRef<SSAValue *> operands) {
result->addOperands(operands); result->addOperands(operands);
result->types.push_back(memrefType); result->types.push_back(memrefType);
} }
void AllocOp::print(OpAsmPrinter *p) const { void AllocOp::print(OpAsmPrinter *p) const {
MemRefType *type = getType(); MemRefType type = getType();
*p << "alloc"; *p << "alloc";
// Print dynamic dimension operands. // Print dynamic dimension operands.
printDimAndSymbolList(operand_begin(), operand_end(), printDimAndSymbolList(operand_begin(), operand_end(),
type->getNumDynamicDims(), p); type.getNumDynamicDims(), p);
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
*p << " : " << *type; *p << " : " << type;
} }
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MemRefType *type; MemRefType type;
// Parse the dimension operands and optional symbol operands, followed by a // Parse the dimension operands and optional symbol operands, followed by a
// memref type. // memref type.
@ -170,7 +170,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
// Verification still checks that the total number of operands matches // Verification still checks that the total number of operands matches
// the number of symbols in the affine map, plus the number of dynamic // the number of symbols in the affine map, plus the number of dynamic
// dimensions in the memref. // dimensions in the memref.
if (numDimOperands != type->getNumDynamicDims()) { if (numDimOperands != type.getNumDynamicDims()) {
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
"dimension operand count does not equal memref " "dimension operand count does not equal memref "
"dynamic dimension count"); "dynamic dimension count");
@ -180,13 +180,13 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
} }
bool AllocOp::verify() const { bool AllocOp::verify() const {
auto *memRefType = dyn_cast<MemRefType>(getResult()->getType()); auto memRefType = getResult()->getType().dyn_cast<MemRefType>();
if (!memRefType) if (!memRefType)
return emitOpError("result must be a memref"); return emitOpError("result must be a memref");
unsigned numSymbols = 0; unsigned numSymbols = 0;
if (!memRefType->getAffineMaps().empty()) { if (!memRefType.getAffineMaps().empty()) {
AffineMap affineMap = memRefType->getAffineMaps()[0]; AffineMap affineMap = memRefType.getAffineMaps()[0];
// Store number of symbols used in affine map (used in subsequent check). // Store number of symbols used in affine map (used in subsequent check).
numSymbols = affineMap.getNumSymbols(); numSymbols = affineMap.getNumSymbols();
// TODO(zinenko): this check does not belong to AllocOp, or any other op but // TODO(zinenko): this check does not belong to AllocOp, or any other op but
@ -195,10 +195,10 @@ bool AllocOp::verify() const {
// Remove when we can emit errors directly from *Type::get(...) functions. // Remove when we can emit errors directly from *Type::get(...) functions.
// //
// Verify that the layout affine map matches the rank of the memref. // Verify that the layout affine map matches the rank of the memref.
if (affineMap.getNumDims() != memRefType->getRank()) if (affineMap.getNumDims() != memRefType.getRank())
return emitOpError("affine map dimension count must equal memref rank"); return emitOpError("affine map dimension count must equal memref rank");
} }
unsigned numDynamicDims = memRefType->getNumDynamicDims(); unsigned numDynamicDims = memRefType.getNumDynamicDims();
// Check that the total number of operands matches the number of symbols in // Check that the total number of operands matches the number of symbols in
// the affine map, plus the number of dynamic dimensions specified in the // the affine map, plus the number of dynamic dimensions specified in the
// memref type. // memref type.
@ -208,7 +208,7 @@ bool AllocOp::verify() const {
} }
// Verify that all operands are of type Index. // Verify that all operands are of type Index.
for (auto *operand : getOperands()) { for (auto *operand : getOperands()) {
if (!operand->getType()->isIndex()) if (!operand->getType().isIndex())
return emitOpError("requires operands to be of type Index"); return emitOpError("requires operands to be of type Index");
} }
return false; return false;
@ -239,13 +239,13 @@ struct SimplifyAllocConst : public Pattern {
// Ok, we have one or more constant operands. Collect the non-constant ones // Ok, we have one or more constant operands. Collect the non-constant ones
// and keep track of the resultant memref type to build. // and keep track of the resultant memref type to build.
SmallVector<int, 4> newShapeConstants; SmallVector<int, 4> newShapeConstants;
newShapeConstants.reserve(memrefType->getRank()); newShapeConstants.reserve(memrefType.getRank());
SmallVector<SSAValue *, 4> newOperands; SmallVector<SSAValue *, 4> newOperands;
SmallVector<SSAValue *, 4> droppedOperands; SmallVector<SSAValue *, 4> droppedOperands;
unsigned dynamicDimPos = 0; unsigned dynamicDimPos = 0;
for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) { for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
int dimSize = memrefType->getDimSize(dim); int dimSize = memrefType.getDimSize(dim);
// If this is already static dimension, keep it. // If this is already static dimension, keep it.
if (dimSize != -1) { if (dimSize != -1) {
newShapeConstants.push_back(dimSize); newShapeConstants.push_back(dimSize);
@ -267,10 +267,10 @@ struct SimplifyAllocConst : public Pattern {
} }
// Create new memref type (which will have fewer dynamic dimensions). // Create new memref type (which will have fewer dynamic dimensions).
auto *newMemRefType = MemRefType::get( auto newMemRefType = MemRefType::get(
newShapeConstants, memrefType->getElementType(), newShapeConstants, memrefType.getElementType(),
memrefType->getAffineMaps(), memrefType->getMemorySpace()); memrefType.getAffineMaps(), memrefType.getMemorySpace());
assert(newOperands.size() == newMemRefType->getNumDynamicDims()); assert(newOperands.size() == newMemRefType.getNumDynamicDims());
// Create and insert the alloc op for the new memref. // Create and insert the alloc op for the new memref.
auto newAlloc = auto newAlloc =
@ -297,13 +297,13 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<SSAValue *> operands) { ArrayRef<SSAValue *> operands) {
result->addOperands(operands); result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee)); result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(callee->getType()->getResults()); result->addTypes(callee->getType().getResults());
} }
bool CallOp::parse(OpAsmParser *parser, OperationState *result) { bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
StringRef calleeName; StringRef calleeName;
llvm::SMLoc calleeLoc; llvm::SMLoc calleeLoc;
FunctionType *calleeType = nullptr; FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands; SmallVector<OpAsmParser::OperandType, 4> operands;
Function *callee = nullptr; Function *callee = nullptr;
if (parser->parseFunctionName(calleeName, calleeLoc) || if (parser->parseFunctionName(calleeName, calleeLoc) ||
@ -312,8 +312,8 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) || parser->parseColonType(calleeType) ||
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) || parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
parser->addTypesToList(calleeType->getResults(), result->types) || parser->addTypesToList(calleeType.getResults(), result->types) ||
parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc, parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
result->operands)) result->operands))
return true; return true;
@ -328,7 +328,7 @@ void CallOp::print(OpAsmPrinter *p) const {
p->printOperands(getOperands()); p->printOperands(getOperands());
*p << ')'; *p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
*p << " : " << *getCallee()->getType(); *p << " : " << getCallee()->getType();
} }
bool CallOp::verify() const { bool CallOp::verify() const {
@ -338,20 +338,20 @@ bool CallOp::verify() const {
return emitOpError("requires a 'callee' function attribute"); return emitOpError("requires a 'callee' function attribute");
// Verify that the operand and result types match the callee. // Verify that the operand and result types match the callee.
auto *fnType = fnAttr.getValue()->getType(); auto fnType = fnAttr.getValue()->getType();
if (fnType->getNumInputs() != getNumOperands()) if (fnType.getNumInputs() != getNumOperands())
return emitOpError("incorrect number of operands for callee"); return emitOpError("incorrect number of operands for callee");
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
if (getOperand(i)->getType() != fnType->getInput(i)) if (getOperand(i)->getType() != fnType.getInput(i))
return emitOpError("operand type mismatch"); return emitOpError("operand type mismatch");
} }
if (fnType->getNumResults() != getNumResults()) if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee"); return emitOpError("incorrect number of results for callee");
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
if (getResult(i)->getType() != fnType->getResult(i)) if (getResult(i)->getType() != fnType.getResult(i))
return emitOpError("result type mismatch"); return emitOpError("result type mismatch");
} }
@ -364,14 +364,14 @@ bool CallOp::verify() const {
void CallIndirectOp::build(Builder *builder, OperationState *result, void CallIndirectOp::build(Builder *builder, OperationState *result,
SSAValue *callee, ArrayRef<SSAValue *> operands) { SSAValue *callee, ArrayRef<SSAValue *> operands) {
auto *fnType = cast<FunctionType>(callee->getType()); auto fnType = callee->getType().cast<FunctionType>();
result->operands.push_back(callee); result->operands.push_back(callee);
result->addOperands(operands); result->addOperands(operands);
result->addTypes(fnType->getResults()); result->addTypes(fnType.getResults());
} }
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
FunctionType *calleeType = nullptr; FunctionType calleeType;
OpAsmParser::OperandType callee; OpAsmParser::OperandType callee;
llvm::SMLoc operandsLoc; llvm::SMLoc operandsLoc;
SmallVector<OpAsmParser::OperandType, 4> operands; SmallVector<OpAsmParser::OperandType, 4> operands;
@ -382,9 +382,9 @@ bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) || parser->parseColonType(calleeType) ||
parser->resolveOperand(callee, calleeType, result->operands) || parser->resolveOperand(callee, calleeType, result->operands) ||
parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc, parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
result->operands) || result->operands) ||
parser->addTypesToList(calleeType->getResults(), result->types); parser->addTypesToList(calleeType.getResults(), result->types);
} }
void CallIndirectOp::print(OpAsmPrinter *p) const { void CallIndirectOp::print(OpAsmPrinter *p) const {
@ -395,29 +395,29 @@ void CallIndirectOp::print(OpAsmPrinter *p) const {
p->printOperands(++operandRange.begin(), operandRange.end()); p->printOperands(++operandRange.begin(), operandRange.end());
*p << ')'; *p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee"); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
*p << " : " << *getCallee()->getType(); *p << " : " << getCallee()->getType();
} }
bool CallIndirectOp::verify() const { bool CallIndirectOp::verify() const {
// The callee must be a function. // The callee must be a function.
auto *fnType = dyn_cast<FunctionType>(getCallee()->getType()); auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
if (!fnType) if (!fnType)
return emitOpError("callee must have function type"); return emitOpError("callee must have function type");
// Verify that the operand and result types match the callee. // Verify that the operand and result types match the callee.
if (fnType->getNumInputs() != getNumOperands() - 1) if (fnType.getNumInputs() != getNumOperands() - 1)
return emitOpError("incorrect number of operands for callee"); return emitOpError("incorrect number of operands for callee");
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) { for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
if (getOperand(i + 1)->getType() != fnType->getInput(i)) if (getOperand(i + 1)->getType() != fnType.getInput(i))
return emitOpError("operand type mismatch"); return emitOpError("operand type mismatch");
} }
if (fnType->getNumResults() != getNumResults()) if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee"); return emitOpError("incorrect number of results for callee");
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) { for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
if (getResult(i)->getType() != fnType->getResult(i)) if (getResult(i)->getType() != fnType.getResult(i))
return emitOpError("result type mismatch"); return emitOpError("result type mismatch");
} }
@ -434,19 +434,19 @@ void DeallocOp::build(Builder *builder, OperationState *result,
} }
void DeallocOp::print(OpAsmPrinter *p) const { void DeallocOp::print(OpAsmPrinter *p) const {
*p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType(); *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType();
} }
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) { bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo; OpAsmParser::OperandType memrefInfo;
MemRefType *type; MemRefType type;
return parser->parseOperand(memrefInfo) || parser->parseColonType(type) || return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands); parser->resolveOperand(memrefInfo, type, result->operands);
} }
bool DeallocOp::verify() const { bool DeallocOp::verify() const {
if (!isa<MemRefType>(getMemRef()->getType())) if (!getMemRef()->getType().isa<MemRefType>())
return emitOpError("operand must be a memref"); return emitOpError("operand must be a memref");
return false; return false;
} }
@ -472,13 +472,13 @@ void DimOp::build(Builder *builder, OperationState *result,
void DimOp::print(OpAsmPrinter *p) const { void DimOp::print(OpAsmPrinter *p) const {
*p << "dim " << *getOperand() << ", " << getIndex(); *p << "dim " << *getOperand() << ", " << getIndex();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index"); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
*p << " : " << *getOperand()->getType(); *p << " : " << getOperand()->getType();
} }
bool DimOp::parse(OpAsmParser *parser, OperationState *result) { bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo; OpAsmParser::OperandType operandInfo;
IntegerAttr indexAttr; IntegerAttr indexAttr;
Type *type; Type type;
return parser->parseOperand(operandInfo) || parser->parseComma() || return parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr, "index", result->attributes) || parser->parseAttribute(indexAttr, "index", result->attributes) ||
@ -496,15 +496,15 @@ bool DimOp::verify() const {
return emitOpError("requires an integer attribute named 'index'"); return emitOpError("requires an integer attribute named 'index'");
uint64_t index = (uint64_t)indexAttr.getValue(); uint64_t index = (uint64_t)indexAttr.getValue();
auto *type = getOperand()->getType(); auto type = getOperand()->getType();
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) { if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
if (index >= tensorType->getRank()) if (index >= tensorType.getRank())
return emitOpError("index is out of range"); return emitOpError("index is out of range");
} else if (auto *memrefType = dyn_cast<MemRefType>(type)) { } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
if (index >= memrefType->getRank()) if (index >= memrefType.getRank())
return emitOpError("index is out of range"); return emitOpError("index is out of range");
} else if (isa<UnrankedTensorType>(type)) { } else if (type.isa<UnrankedTensorType>()) {
// ok, assumed to be in-range. // ok, assumed to be in-range.
} else { } else {
return emitOpError("requires an operand with tensor or memref type"); return emitOpError("requires an operand with tensor or memref type");
@ -516,12 +516,12 @@ bool DimOp::verify() const {
Attribute DimOp::constantFold(ArrayRef<Attribute> operands, Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const { MLIRContext *context) const {
// Constant fold dim when the size along the index referred to is a constant. // Constant fold dim when the size along the index referred to is a constant.
auto *opType = getOperand()->getType(); auto opType = getOperand()->getType();
int indexSize = -1; int indexSize = -1;
if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) { if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
indexSize = tensorType->getShape()[getIndex()]; indexSize = tensorType.getShape()[getIndex()];
} else if (auto *memrefType = dyn_cast<MemRefType>(opType)) { } else if (auto memrefType = opType.dyn_cast<MemRefType>()) {
indexSize = memrefType->getShape()[getIndex()]; indexSize = memrefType.getShape()[getIndex()];
} }
if (indexSize >= 0) if (indexSize >= 0)
@ -544,9 +544,9 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
p->printOperands(getTagIndices()); p->printOperands(getTagIndices());
*p << ']'; *p << ']';
p->printOptionalAttrDict(getAttrs()); p->printOptionalAttrDict(getAttrs());
*p << " : " << *getSrcMemRef()->getType(); *p << " : " << getSrcMemRef()->getType();
*p << ", " << *getDstMemRef()->getType(); *p << ", " << getDstMemRef()->getType();
*p << ", " << *getTagMemRef()->getType(); *p << ", " << getTagMemRef()->getType();
} }
// Parse DmaStartOp. // Parse DmaStartOp.
@ -566,8 +566,8 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo; OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
SmallVector<Type *, 3> types; SmallVector<Type, 3> types;
auto *indexType = parser->getBuilder().getIndexType(); auto indexType = parser->getBuilder().getIndexType();
// Parse and resolve the following list of operands: // Parse and resolve the following list of operands:
// *) source memref followed by its indices (in square brackets). // *) source memref followed by its indices (in square brackets).
@ -601,12 +601,12 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
return true; return true;
// Check that source/destination index list size matches associated rank. // Check that source/destination index list size matches associated rank.
if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() || if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank()) dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
"memref rank not equal to indices count"); "memref rank not equal to indices count");
if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank()) if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count"); "tag memref rank not equal to indices count");
@ -632,7 +632,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
p->printOperands(getTagIndices()); p->printOperands(getTagIndices());
*p << "], "; *p << "], ";
p->printOperand(getNumElements()); p->printOperand(getNumElements());
*p << " : " << *getTagMemRef()->getType(); *p << " : " << getTagMemRef()->getType();
} }
// Parse DmaWaitOp. // Parse DmaWaitOp.
@ -642,8 +642,8 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo; OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
Type *type; Type type;
auto *indexType = parser->getBuilder().getIndexType(); auto indexType = parser->getBuilder().getIndexType();
OpAsmParser::OperandType numElementsInfo; OpAsmParser::OperandType numElementsInfo;
// Parse tag memref, its indices, and dma size. // Parse tag memref, its indices, and dma size.
@ -657,7 +657,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
parser->resolveOperand(numElementsInfo, indexType, result->operands)) parser->resolveOperand(numElementsInfo, indexType, result->operands))
return true; return true;
if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank()) if (tagIndexInfos.size() != type.cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count"); "tag memref rank not equal to indices count");
@ -678,10 +678,10 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningPatternList &results,
void ExtractElementOp::build(Builder *builder, OperationState *result, void ExtractElementOp::build(Builder *builder, OperationState *result,
SSAValue *aggregate, SSAValue *aggregate,
ArrayRef<SSAValue *> indices) { ArrayRef<SSAValue *> indices) {
auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType()); auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
result->addOperands(aggregate); result->addOperands(aggregate);
result->addOperands(indices); result->addOperands(indices);
result->types.push_back(aggregateType->getElementType()); result->types.push_back(aggregateType.getElementType());
} }
void ExtractElementOp::print(OpAsmPrinter *p) const { void ExtractElementOp::print(OpAsmPrinter *p) const {
@ -689,13 +689,13 @@ void ExtractElementOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices()); p->printOperands(getIndices());
*p << ']'; *p << ']';
p->printOptionalAttrDict(getAttrs()); p->printOptionalAttrDict(getAttrs());
*p << " : " << *getAggregate()->getType(); *p << " : " << getAggregate()->getType();
} }
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType aggregateInfo; OpAsmParser::OperandType aggregateInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo;
VectorOrTensorType *type; VectorOrTensorType type;
auto affineIntTy = parser->getBuilder().getIndexType(); auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(aggregateInfo) || return parser->parseOperand(aggregateInfo) ||
@ -705,26 +705,26 @@ bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseColonType(type) || parser->parseColonType(type) ||
parser->resolveOperand(aggregateInfo, type, result->operands) || parser->resolveOperand(aggregateInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type->getElementType(), result->types); parser->addTypeToList(type.getElementType(), result->types);
} }
bool ExtractElementOp::verify() const { bool ExtractElementOp::verify() const {
if (getNumOperands() == 0) if (getNumOperands() == 0)
return emitOpError("expected an aggregate to index into"); return emitOpError("expected an aggregate to index into");
auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType()); auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>();
if (!aggregateType) if (!aggregateType)
return emitOpError("first operand must be a vector or tensor"); return emitOpError("first operand must be a vector or tensor");
if (getType() != aggregateType->getElementType()) if (getType() != aggregateType.getElementType())
return emitOpError("result type must match element type of aggregate"); return emitOpError("result type must match element type of aggregate");
for (auto *idx : getIndices()) for (auto *idx : getIndices())
if (!idx->getType()->isIndex()) if (!idx->getType().isIndex())
return emitOpError("index to extract_element must have 'index' type"); return emitOpError("index to extract_element must have 'index' type");
// Verify the # indices match if we have a ranked type. // Verify the # indices match if we have a ranked type.
auto aggregateRank = aggregateType->getRank(); auto aggregateRank = aggregateType.getRank();
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1) if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
return emitOpError("incorrect number of indices for extract_element"); return emitOpError("incorrect number of indices for extract_element");
@ -737,10 +737,10 @@ bool ExtractElementOp::verify() const {
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref, void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
ArrayRef<SSAValue *> indices) { ArrayRef<SSAValue *> indices) {
auto *memrefType = cast<MemRefType>(memref->getType()); auto memrefType = memref->getType().cast<MemRefType>();
result->addOperands(memref); result->addOperands(memref);
result->addOperands(indices); result->addOperands(indices);
result->types.push_back(memrefType->getElementType()); result->types.push_back(memrefType.getElementType());
} }
void LoadOp::print(OpAsmPrinter *p) const { void LoadOp::print(OpAsmPrinter *p) const {
@ -748,13 +748,13 @@ void LoadOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices()); p->printOperands(getIndices());
*p << ']'; *p << ']';
p->printOptionalAttrDict(getAttrs()); p->printOptionalAttrDict(getAttrs());
*p << " : " << *getMemRefType(); *p << " : " << getMemRefType();
} }
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo; OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType *type; MemRefType type;
auto affineIntTy = parser->getBuilder().getIndexType(); auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(memrefInfo) || return parser->parseOperand(memrefInfo) ||
@ -764,25 +764,25 @@ bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseColonType(type) || parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands) || parser->resolveOperand(memrefInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type->getElementType(), result->types); parser->addTypeToList(type.getElementType(), result->types);
} }
bool LoadOp::verify() const { bool LoadOp::verify() const {
if (getNumOperands() == 0) if (getNumOperands() == 0)
return emitOpError("expected a memref to load from"); return emitOpError("expected a memref to load from");
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
if (!memRefType) if (!memRefType)
return emitOpError("first operand must be a memref"); return emitOpError("first operand must be a memref");
if (getType() != memRefType->getElementType()) if (getType() != memRefType.getElementType())
return emitOpError("result type must match element type of memref"); return emitOpError("result type must match element type of memref");
if (memRefType->getRank() != getNumOperands() - 1) if (memRefType.getRank() != getNumOperands() - 1)
return emitOpError("incorrect number of indices for load"); return emitOpError("incorrect number of indices for load");
for (auto *idx : getIndices()) for (auto *idx : getIndices())
if (!idx->getType()->isIndex()) if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type"); return emitOpError("index to load must have 'index' type");
// TODO: Verify we have the right number of indices. // TODO: Verify we have the right number of indices.
@ -804,31 +804,31 @@ void LoadOp::getCanonicalizationPatterns(OwningPatternList &results,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool MemRefCastOp::verify() const { bool MemRefCastOp::verify() const {
auto *opType = dyn_cast<MemRefType>(getOperand()->getType()); auto opType = getOperand()->getType().dyn_cast<MemRefType>();
auto *resType = dyn_cast<MemRefType>(getType()); auto resType = getType().dyn_cast<MemRefType>();
if (!opType || !resType) if (!opType || !resType)
return emitOpError("requires input and result types to be memrefs"); return emitOpError("requires input and result types to be memrefs");
if (opType == resType) if (opType == resType)
return emitOpError("requires the input and result type to be different"); return emitOpError("requires the input and result type to be different");
if (opType->getElementType() != resType->getElementType()) if (opType.getElementType() != resType.getElementType())
return emitOpError( return emitOpError(
"requires input and result element types to be the same"); "requires input and result element types to be the same");
if (opType->getAffineMaps() != resType->getAffineMaps()) if (opType.getAffineMaps() != resType.getAffineMaps())
return emitOpError("requires input and result mappings to be the same"); return emitOpError("requires input and result mappings to be the same");
if (opType->getMemorySpace() != resType->getMemorySpace()) if (opType.getMemorySpace() != resType.getMemorySpace())
return emitOpError( return emitOpError(
"requires input and result memory spaces to be the same"); "requires input and result memory spaces to be the same");
// They must have the same rank, and any specified dimensions must match. // They must have the same rank, and any specified dimensions must match.
if (opType->getRank() != resType->getRank()) if (opType.getRank() != resType.getRank())
return emitOpError("requires input and result ranks to match"); return emitOpError("requires input and result ranks to match");
for (unsigned i = 0, e = opType->getRank(); i != e; ++i) { for (unsigned i = 0, e = opType.getRank(); i != e; ++i) {
int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i); int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i);
if (opDim != -1 && resultDim != -1 && opDim != resultDim) if (opDim != -1 && resultDim != -1 && opDim != resultDim)
return emitOpError("requires static dimensions to match"); return emitOpError("requires static dimensions to match");
} }
@ -923,14 +923,14 @@ void StoreOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices()); p->printOperands(getIndices());
*p << ']'; *p << ']';
p->printOptionalAttrDict(getAttrs()); p->printOptionalAttrDict(getAttrs());
*p << " : " << *getMemRefType(); *p << " : " << getMemRefType();
} }
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo; OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType *memrefType; MemRefType memrefType;
auto affineIntTy = parser->getBuilder().getIndexType(); auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(storeValueInfo) || parser->parseComma() || return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
@ -939,7 +939,7 @@ bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) || OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(memrefType) || parser->parseColonType(memrefType) ||
parser->resolveOperand(storeValueInfo, memrefType->getElementType(), parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
result->operands) || result->operands) ||
parser->resolveOperand(memrefInfo, memrefType, result->operands) || parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands); parser->resolveOperands(indexInfo, affineIntTy, result->operands);
@ -950,19 +950,19 @@ bool StoreOp::verify() const {
return emitOpError("expected a value to store and a memref"); return emitOpError("expected a value to store and a memref");
// Second operand is a memref type. // Second operand is a memref type.
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
if (!memRefType) if (!memRefType)
return emitOpError("second operand must be a memref"); return emitOpError("second operand must be a memref");
// First operand must have same type as memref element type. // First operand must have same type as memref element type.
if (getValueToStore()->getType() != memRefType->getElementType()) if (getValueToStore()->getType() != memRefType.getElementType())
return emitOpError("first operand must have same type memref element type"); return emitOpError("first operand must have same type memref element type");
if (getNumOperands() != 2 + memRefType->getRank()) if (getNumOperands() != 2 + memRefType.getRank())
return emitOpError("store index operand count not equal to memref rank"); return emitOpError("store index operand count not equal to memref rank");
for (auto *idx : getIndices()) for (auto *idx : getIndices())
if (!idx->getType()->isIndex()) if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type"); return emitOpError("index to load must have 'index' type");
// TODO: Verify we have the right number of indices. // TODO: Verify we have the right number of indices.
@ -1046,31 +1046,31 @@ void SubIOp::getCanonicalizationPatterns(OwningPatternList &results,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool TensorCastOp::verify() const { bool TensorCastOp::verify() const {
auto *opType = dyn_cast<TensorType>(getOperand()->getType()); auto opType = getOperand()->getType().dyn_cast<TensorType>();
auto *resType = dyn_cast<TensorType>(getType()); auto resType = getType().dyn_cast<TensorType>();
if (!opType || !resType) if (!opType || !resType)
return emitOpError("requires input and result types to be tensors"); return emitOpError("requires input and result types to be tensors");
if (opType == resType) if (opType == resType)
return emitOpError("requires the input and result type to be different"); return emitOpError("requires the input and result type to be different");
if (opType->getElementType() != resType->getElementType()) if (opType.getElementType() != resType.getElementType())
return emitOpError( return emitOpError(
"requires input and result element types to be the same"); "requires input and result element types to be the same");
// If the source or destination are unranked, then the cast is valid. // If the source or destination are unranked, then the cast is valid.
auto *opRType = dyn_cast<RankedTensorType>(opType); auto opRType = opType.dyn_cast<RankedTensorType>();
auto *resRType = dyn_cast<RankedTensorType>(resType); auto resRType = resType.dyn_cast<RankedTensorType>();
if (!opRType || !resRType) if (!opRType || !resRType)
return false; return false;
// If they are both ranked, they have to have the same rank, and any specified // If they are both ranked, they have to have the same rank, and any specified
// dimensions must match. // dimensions must match.
if (opRType->getRank() != resRType->getRank()) if (opRType.getRank() != resRType.getRank())
return emitOpError("requires input and result ranks to match"); return emitOpError("requires input and result ranks to match");
for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) { for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) {
int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i); int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i);
if (opDim != -1 && resultDim != -1 && opDim != resultDim) if (opDim != -1 && resultDim != -1 && opDim != resultDim)
return emitOpError("requires static dimensions to match"); return emitOpError("requires static dimensions to match");
} }

View File

@ -31,7 +31,7 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
SmallVector<SSAValue *, 8> existingConstants; SmallVector<SSAValue *, 8> existingConstants;
// Operation statements that were folded and that need to be erased. // Operation statements that were folded and that need to be erased.
std::vector<OperationStmt *> opStmtsToErase; std::vector<OperationStmt *> opStmtsToErase;
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>; using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>;
bool foldOperation(Operation *op, bool foldOperation(Operation *op,
SmallVectorImpl<SSAValue *> &existingConstants, SmallVectorImpl<SSAValue *> &existingConstants,
@ -106,7 +106,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) { for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
auto &inst = *instIt++; auto &inst = *instIt++;
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * { auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
builder.setInsertionPoint(&inst); builder.setInsertionPoint(&inst);
return builder.create<ConstantOp>(inst.getLoc(), value, type); return builder.create<ConstantOp>(inst.getLoc(), value, type);
}; };
@ -134,7 +134,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
// Override the walker's operation statement visit for constant folding. // Override the walker's operation statement visit for constant folding.
void ConstantFold::visitOperationStmt(OperationStmt *stmt) { void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * { auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
MLFuncBuilder builder(stmt); MLFuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type); return builder.create<ConstantOp>(stmt->getLoc(), value, type);
}; };

View File

@ -77,23 +77,23 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
bInner.setInsertionPoint(forStmt, forStmt->begin()); bInner.setInsertionPoint(forStmt, forStmt->begin());
// Doubles the shape with a leading dimension extent of 2. // Doubles the shape with a leading dimension extent of 2.
auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * { auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
// Add the leading dimension in the shape for the double buffer. // Add the leading dimension in the shape for the double buffer.
ArrayRef<int> shape = oldMemRefType->getShape(); ArrayRef<int> shape = oldMemRefType.getShape();
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end()); SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
shapeSizes.insert(shapeSizes.begin(), 2); shapeSizes.insert(shapeSizes.begin(), 2);
auto *newMemRefType = auto newMemRefType =
bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {}, bInner.getMemRefType(shapeSizes, oldMemRefType.getElementType(), {},
oldMemRefType->getMemorySpace()); oldMemRefType.getMemorySpace());
return newMemRefType; return newMemRefType;
}; };
auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType())); auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>());
// Create and place the alloc at the top level. // Create and place the alloc at the top level.
MLFuncBuilder topBuilder(forStmt->getFunction()); MLFuncBuilder topBuilder(forStmt->getFunction());
auto *newMemRef = cast<MLValue>( auto newMemRef = cast<MLValue>(
topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType) topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
->getResult()); ->getResult());

View File

@ -78,7 +78,7 @@ private:
/// As part of canonicalization, we move constants to the top of the entry /// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of /// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for. /// constants we have done this for.
DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants; DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
}; };
}; // end anonymous namespace }; // end anonymous namespace

View File

@ -52,9 +52,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
MLValue *newMemRef, MLValue *newMemRef,
ArrayRef<MLValue *> extraIndices, ArrayRef<MLValue *> extraIndices,
AffineMap indexRemap) { AffineMap indexRemap) {
unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank(); unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode (void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank(); unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; (void)newMemRefRank;
if (indexRemap) { if (indexRemap) {
assert(indexRemap.getNumInputs() == oldMemRefRank); assert(indexRemap.getNumInputs() == oldMemRefRank);
@ -64,8 +64,8 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
} }
// Assert same elemental type. // Assert same elemental type.
assert(cast<MemRefType>(oldMemRef->getType())->getElementType() == assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
cast<MemRefType>(newMemRef->getType())->getElementType()); newMemRef->getType().cast<MemRefType>().getElementType());
// Check if memref was used in a non-deferencing context. // Check if memref was used in a non-deferencing context.
for (const StmtOperand &use : oldMemRef->getUses()) { for (const StmtOperand &use : oldMemRef->getUses()) {
@ -139,7 +139,7 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
opStmt->operand_end()); opStmt->operand_end());
// Result types don't change. Both memref's are of the same elemental type. // Result types don't change. Both memref's are of the same elemental type.
SmallVector<Type *, 8> resultTypes; SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults()); resultTypes.reserve(opStmt->getNumResults());
for (const auto *result : opStmt->getResults()) for (const auto *result : opStmt->getResults())
resultTypes.push_back(result->getType()); resultTypes.push_back(result->getType());

View File

@ -202,15 +202,15 @@ static bool analyzeProfitability(MLFunctionMatches matches,
/// sizes specified by vectorSize. The MemRef lives in the same memory space as /// sizes specified by vectorSize. The MemRef lives in the same memory space as
/// tmpl. The MemRef should be promoted to a closer memory address space in a /// tmpl. The MemRef should be promoted to a closer memory address space in a
/// later pass. /// later pass.
static MemRefType *getVectorizedMemRefType(MemRefType *tmpl, static MemRefType getVectorizedMemRefType(MemRefType tmpl,
ArrayRef<int> vectorSizes) { ArrayRef<int> vectorSizes) {
auto *elementType = tmpl->getElementType(); auto elementType = tmpl.getElementType();
assert(!dyn_cast<VectorType>(elementType) && assert(!elementType.dyn_cast<VectorType>() &&
"Can't vectorize an already vector type"); "Can't vectorize an already vector type");
assert(tmpl->getAffineMaps().empty() && assert(tmpl.getAffineMaps().empty() &&
"Unsupported non-implicit identity map"); "Unsupported non-implicit identity map");
return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {}, return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {},
tmpl->getMemorySpace()); tmpl.getMemorySpace());
} }
/// Creates an unaligned load with the following semantics: /// Creates an unaligned load with the following semantics:
@ -258,7 +258,7 @@ static void createUnalignedLoad(MLFuncBuilder *b, Location *loc,
operands.insert(operands.end(), dstMemRef); operands.insert(operands.end(), dstMemRef);
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end()); operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
using functional::map; using functional::map;
std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * { std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
return v->getType(); return v->getType();
}; };
auto types = map(getType, operands); auto types = map(getType, operands);
@ -310,7 +310,7 @@ static void createUnalignedStore(MLFuncBuilder *b, Location *loc,
operands.insert(operands.end(), dstMemRef); operands.insert(operands.end(), dstMemRef);
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end()); operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
using functional::map; using functional::map;
std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * { std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
return v->getType(); return v->getType();
}; };
auto types = map(getType, operands); auto types = map(getType, operands);
@ -348,8 +348,9 @@ static std::function<ToType *(T *)> unwrapPtr() {
template <typename LoadOrStoreOpPointer> template <typename LoadOrStoreOpPointer>
static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp, static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
ArrayRef<int> vectorSize) { ArrayRef<int> vectorSize) {
auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType()); auto memRefType =
auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize); memoryOp->getMemRef()->getType().template cast<MemRefType>();
auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
// Materialize a MemRef with 1 vector. // Materialize a MemRef with 1 vector.
auto *opStmt = cast<OperationStmt>(memoryOp->getOperation()); auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());