Implement value type abstraction for types.

This is done by changing Type to be a POD interface around an underlying pointer storage and adding in-class support for isa/dyn_cast/cast.

PiperOrigin-RevId: 219372163
This commit is contained in:
River Riddle 2018-10-30 14:59:22 -07:00 committed by jpienaar
parent 75376b8e33
commit 4c465a181d
41 changed files with 998 additions and 811 deletions

View File

@ -51,7 +51,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
/// whether indices[dim] is independent of the value `input`.
// For now we assume no layout map or identity layout map in the MemRef.
// TODO(ntv): support more than identity layout map.
bool isAccessInvariant(const MLValue &input, MemRefType *memRefType,
bool isAccessInvariant(const MLValue &input, MemRefType memRefType,
llvm::ArrayRef<MLValue *> indices, unsigned dim);
/// Checks whether all the LoadOp and StoreOp matched have access indexing

View File

@ -250,9 +250,9 @@ public:
TypeAttr() = default;
/* implicit */ TypeAttr(Attribute::ImplType *ptr);
static TypeAttr get(Type *type, MLIRContext *context);
static TypeAttr get(Type type, MLIRContext *context);
Type *getValue() const;
Type getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Type; }
@ -277,7 +277,7 @@ public:
Function *getValue() const;
FunctionType *getType() const;
FunctionType getType() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Function; }
@ -294,7 +294,7 @@ public:
ElementsAttr() = default;
/* implicit */ ElementsAttr(Attribute::ImplType *ptr);
VectorOrTensorType *getType() const;
VectorOrTensorType getType() const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(Kind kind) {
@ -313,7 +313,7 @@ public:
SplatElementsAttr() = default;
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
static SplatElementsAttr get(VectorOrTensorType *type, Attribute elt);
static SplatElementsAttr get(VectorOrTensorType type, Attribute elt);
Attribute getValue() const;
/// Method for support type inquiry through isa, cast and dyn_cast.
@ -335,12 +335,12 @@ public:
/// width specified by the element type (note all float type are 64 bits).
/// When the value is retrieved, the bits are read from the storage and extend
/// to 64 bits if necessary.
static DenseElementsAttr get(VectorOrTensorType *type, ArrayRef<char> data);
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data);
// TODO: Read the data from the attribute list and compress them
// to a character array. Then call the above method to construct the
// attribute.
static DenseElementsAttr get(VectorOrTensorType *type,
static DenseElementsAttr get(VectorOrTensorType type,
ArrayRef<Attribute> values);
void getValues(SmallVectorImpl<Attribute> &values) const;
@ -410,7 +410,7 @@ public:
OpaqueElementsAttr() = default;
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
static OpaqueElementsAttr get(VectorOrTensorType *type, StringRef bytes);
static OpaqueElementsAttr get(VectorOrTensorType type, StringRef bytes);
StringRef getValue() const;
@ -440,7 +440,7 @@ public:
SparseElementsAttr() = default;
/* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
static SparseElementsAttr get(VectorOrTensorType *type,
static SparseElementsAttr get(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values);

View File

@ -64,10 +64,10 @@ public:
bool args_empty() const { return arguments.empty(); }
/// Add one value to the operand list.
BBArgument *addArgument(Type *type);
BBArgument *addArgument(Type type);
/// Add one argument to the argument list for each type specified in the list.
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type *> types);
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
unsigned getNumArguments() const { return arguments.size(); }
BBArgument *getArgument(unsigned i) { return arguments[i]; }

View File

@ -68,29 +68,28 @@ public:
unsigned column);
// Types.
FloatType *getBF16Type();
FloatType *getF16Type();
FloatType *getF32Type();
FloatType *getF64Type();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getF32Type();
FloatType getF64Type();
OtherType *getIndexType();
OtherType *getTFControlType();
OtherType *getTFStringType();
OtherType *getTFResourceType();
OtherType *getTFVariantType();
OtherType *getTFComplex64Type();
OtherType *getTFComplex128Type();
OtherType *getTFF32REFType();
OtherType getIndexType();
OtherType getTFControlType();
OtherType getTFStringType();
OtherType getTFResourceType();
OtherType getTFVariantType();
OtherType getTFComplex64Type();
OtherType getTFComplex128Type();
OtherType getTFF32REFType();
IntegerType *getIntegerType(unsigned width);
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
ArrayRef<Type *> results);
MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap> affineMapComposition = {},
unsigned memorySpace = 0);
VectorType *getVectorType(ArrayRef<int> shape, Type *elementType);
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
UnrankedTensorType *getTensorType(Type *elementType);
IntegerType getIntegerType(unsigned width);
FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
MemRefType getMemRefType(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition = {},
unsigned memorySpace = 0);
VectorType getVectorType(ArrayRef<int> shape, Type elementType);
RankedTensorType getTensorType(ArrayRef<int> shape, Type elementType);
UnrankedTensorType getTensorType(Type elementType);
// Attributes.
@ -102,15 +101,15 @@ public:
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
AffineMapAttr getAffineMapAttr(AffineMap map);
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
TypeAttr getTypeAttr(Type *type);
TypeAttr getTypeAttr(Type type);
FunctionAttr getFunctionAttr(const Function *value);
ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt);
ElementsAttr getDenseElementsAttr(VectorOrTensorType *type,
ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt);
ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
ArrayRef<char> data);
ElementsAttr getSparseElementsAttr(VectorOrTensorType *type,
ElementsAttr getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values);
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes);
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);
@ -366,7 +365,7 @@ public:
/// Creates an operation given the fields.
OperationStmt *createOperation(Location *location, OperationName name,
ArrayRef<MLValue *> operands,
ArrayRef<Type *> types,
ArrayRef<Type> types,
ArrayRef<NamedAttribute> attrs);
/// Create operation of specific op type at the current insertion point.

View File

@ -96,7 +96,7 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
public:
/// Builds a constant op with the specified attribute value and result type.
static void build(Builder *builder, OperationState *result, Attribute value,
Type *type);
Type type);
Attribute getValue() const { return getAttr("value"); }
@ -123,7 +123,7 @@ class ConstantFloatOp : public ConstantOp {
public:
/// Builds a constant float op producing a float of the specified type.
static void build(Builder *builder, OperationState *result,
const APFloat &value, FloatType *type);
const APFloat &value, FloatType type);
APFloat getValue() const {
return getAttrOfType<FloatAttr>("value").getValue();
@ -150,7 +150,7 @@ public:
/// Build a constant int op producing an integer with the specified type,
/// which must be an integer type.
static void build(Builder *builder, OperationState *result, int64_t value,
Type *type);
Type type);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value").getValue();

View File

@ -27,7 +27,7 @@ namespace mlir {
// blocks, each of which includes instructions.
class CFGFunction : public Function {
public:
CFGFunction(Location *location, StringRef name, FunctionType *type,
CFGFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
~CFGFunction();

View File

@ -66,7 +66,7 @@ public:
}
protected:
CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
CFGValue(CFGValueKind kind, Type type) : SSAValueImpl(kind, type) {}
};
/// Basic block arguments are CFG Values.
@ -87,7 +87,7 @@ public:
private:
friend class BasicBlock; // For access to private constructor.
BBArgument(Type *type, BasicBlock *owner)
BBArgument(Type type, BasicBlock *owner)
: CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
/// The owner of this operand.
@ -99,7 +99,7 @@ private:
/// Instruction results are CFG Values.
class InstResult : public CFGValue {
public:
InstResult(Type *type, OperationInst *owner)
InstResult(Type type, OperationInst *owner)
: CFGValue(CFGValueKind::InstResult, type), owner(owner) {}
static bool classof(const SSAValue *value) {

View File

@ -26,6 +26,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h"
@ -55,7 +56,7 @@ public:
Identifier getName() const { return nameAndKind.getPointer(); }
/// Return the type of this function.
FunctionType *getType() const { return type; }
FunctionType getType() const { return type; }
/// Returns all of the attributes on this function.
ArrayRef<NamedAttribute> getAttrs() const;
@ -93,7 +94,7 @@ public:
void emitNote(const Twine &message) const;
protected:
Function(Kind kind, Location *location, StringRef name, FunctionType *type,
Function(Kind kind, Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
~Function();
@ -108,7 +109,7 @@ private:
Location *location;
/// The type of the function.
FunctionType *const type;
FunctionType type;
/// This holds general named attributes for the function.
AttributeListStorage *attrs;
@ -121,7 +122,7 @@ private:
/// defined in some other module.
class ExtFunction : public Function {
public:
ExtFunction(Location *location, StringRef name, FunctionType *type,
ExtFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
/// Methods for support type inquiry through isa, cast, and dyn_cast.

View File

@ -202,7 +202,7 @@ public:
/// Create a new OperationInst with the specified fields.
static OperationInst *create(Location *location, OperationName name,
ArrayRef<CFGValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context);

View File

@ -41,7 +41,7 @@ class MLFunction final
public:
/// Creates a new MLFunction with the specific type.
static MLFunction *create(Location *location, StringRef name,
FunctionType *type,
FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
/// Destroys this statement and its subclass data.
@ -52,7 +52,7 @@ public:
//===--------------------------------------------------------------------===//
/// Returns number of arguments.
unsigned getNumArguments() const { return getType()->getInputs().size(); }
unsigned getNumArguments() const { return getType().getInputs().size(); }
/// Gets argument.
MLFuncArgument *getArgument(unsigned idx) {
@ -103,13 +103,13 @@ public:
}
private:
MLFunction(Location *location, StringRef name, FunctionType *type,
MLFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
// This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;
size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const {
return getType()->getInputs().size();
return getType().getInputs().size();
}
// Internal functions to get argument list used by getArgument() methods.

View File

@ -73,7 +73,7 @@ public:
}
protected:
MLValue(MLValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {}
};
/// This is the value defined by an argument of an ML function.
@ -93,7 +93,7 @@ public:
private:
friend class MLFunction; // For access to private constructor.
MLFuncArgument(Type *type, MLFunction *owner)
MLFuncArgument(Type type, MLFunction *owner)
: MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {}
/// The owner of this operand.
@ -105,7 +105,7 @@ private:
/// This is a value defined by a result of an operation instruction.
class StmtResult : public MLValue {
public:
StmtResult(Type *type, OperationStmt *owner)
StmtResult(Type type, OperationStmt *owner)
: MLValue(MLValueKind::StmtResult, type), owner(owner) {}
static bool classof(const SSAValue *value) {

View File

@ -71,13 +71,13 @@ struct constant_int_op_binder {
bool match(Operation *op) {
if (auto constOp = op->dyn_cast<ConstantOp>()) {
auto *type = constOp->getResult()->getType();
auto type = constOp->getResult()->getType();
auto attr = constOp->getAttr("value");
if (isa<IntegerType>(type)) {
if (type.isa<IntegerType>()) {
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
}
if (isa<VectorOrTensorType>(type)) {
if (type.isa<VectorOrTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value)
.match(splatAttr.getValue());

View File

@ -493,7 +493,7 @@ public:
return this->getOperation()->getResult(0);
}
Type *getType() const { return getResult()->getType(); }
Type getType() const { return getResult()->getType(); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
@ -539,7 +539,7 @@ public:
return this->getOperation()->getResult(i);
}
Type *getType(unsigned i) const { return getResult(i)->getType(); }
Type getType(unsigned i) const { return getResult(i)->getType(); }
static bool verifyTrait(const Operation *op) {
return impl::verifyNResults(op, N);
@ -565,7 +565,7 @@ public:
return this->getOperation()->getResult(i);
}
Type *getType(unsigned i) const { return getResult(i)->getType(); }
Type getType(unsigned i) const { return getResult(i)->getType(); }
static bool verifyTrait(const Operation *op) {
return impl::verifyAtLeastNResults(op, N);
@ -803,7 +803,7 @@ protected:
// which avoids them being template instantiated/duplicated.
namespace impl {
void buildCastOp(Builder *builder, OperationState *result, SSAValue *source,
Type *destType);
Type destType);
bool parseCastOp(OpAsmParser *parser, OperationState *result);
void printCastOp(const Operation *op, OpAsmPrinter *p);
} // namespace impl
@ -819,7 +819,7 @@ class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
OpTrait::HasNoSideEffect, Traits...> {
public:
static void build(Builder *builder, OperationState *result, SSAValue *source,
Type *destType) {
Type destType) {
impl::buildCastOp(builder, result, source, destType);
}
static bool parse(OpAsmParser *parser, OperationState *result) {

View File

@ -67,7 +67,7 @@ public:
printOperand(*it);
}
}
virtual void printType(const Type *type) = 0;
virtual void printType(Type type) = 0;
virtual void printFunctionReference(const Function *func) = 0;
virtual void printAttribute(Attribute attr) = 0;
virtual void printAffineMap(AffineMap map) = 0;
@ -95,8 +95,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const SSAValue &value) {
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Type &type) {
p.printType(&type);
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
p.printType(type);
return p;
}
@ -163,20 +163,20 @@ public:
virtual bool parseComma() = 0;
/// Parse a colon followed by a type.
virtual bool parseColonType(Type *&result) = 0;
virtual bool parseColonType(Type &result) = 0;
/// Parse a type of a specific kind, e.g. a FunctionType.
template <typename TypeType> bool parseColonType(TypeType *&result) {
template <typename TypeType> bool parseColonType(TypeType &result) {
llvm::SMLoc loc;
getCurrentLocation(&loc);
// Parse any kind of type.
Type *type;
Type type;
if (parseColonType(type))
return true;
// Check for the right kind of attribute.
result = dyn_cast<TypeType>(type);
result = type.dyn_cast<TypeType>();
if (!result) {
emitError(loc, "invalid kind of type specified");
return true;
@ -186,15 +186,15 @@ public:
}
/// Parse a colon followed by a type list, which must have at least one type.
virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result) = 0;
virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a keyword followed by a type.
virtual bool parseKeywordType(const char *keyword, Type *&result) = 0;
virtual bool parseKeywordType(const char *keyword, Type &result) = 0;
/// Add the specified type to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and
/// chain through || operators.
bool addTypeToList(Type *type, SmallVectorImpl<Type *> &result) {
bool addTypeToList(Type type, SmallVectorImpl<Type> &result) {
result.push_back(type);
return false;
}
@ -202,7 +202,7 @@ public:
/// Add the specified types to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and
/// chain through || operators.
bool addTypesToList(ArrayRef<Type *> types, SmallVectorImpl<Type *> &result) {
bool addTypesToList(ArrayRef<Type> types, SmallVectorImpl<Type> &result) {
result.append(types.begin(), types.end());
return false;
}
@ -288,13 +288,13 @@ public:
/// Resolve an operand to an SSA value, emitting an error and returning true
/// on failure.
virtual bool resolveOperand(const OperandType &operand, Type *type,
virtual bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<SSAValue *> &result) = 0;
/// Resolve a list of operands to SSA values, emitting an error and returning
/// true on failure, or appending the results to the list on success.
/// This method should be used when all operands have the same type.
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type *type,
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type,
SmallVectorImpl<SSAValue *> &result) {
for (auto elt : operands)
if (resolveOperand(elt, type, result))
@ -306,7 +306,7 @@ public:
/// emitting an error and returning true on failure, or appending the results
/// to the list on success.
virtual bool resolveOperands(ArrayRef<OperandType> operands,
ArrayRef<Type *> types, llvm::SMLoc loc,
ArrayRef<Type> types, llvm::SMLoc loc,
SmallVectorImpl<SSAValue *> &result) {
if (operands.size() != types.size())
return emitError(loc, Twine(operands.size()) +
@ -321,7 +321,7 @@ public:
}
/// Resolve a parse function name and a type into a function reference.
virtual bool resolveFunctionName(StringRef name, FunctionType *type,
virtual bool resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) = 0;
/// Emit a diagnostic at the specified location and return true.

View File

@ -25,6 +25,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/PointerUnion.h"
#include <memory>
@ -191,7 +192,7 @@ struct OperationState {
OperationName name;
SmallVector<SSAValue *, 4> operands;
/// Types of the results of this operation.
SmallVector<Type *, 4> types;
SmallVector<Type, 4> types;
SmallVector<NamedAttribute, 4> attributes;
public:
@ -202,7 +203,7 @@ public:
: context(context), location(location), name(name) {}
OperationState(MLIRContext *context, Location *location, StringRef name,
ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
ArrayRef<NamedAttribute> attributes = {})
: context(context), location(location), name(name, context),
operands(operands.begin(), operands.end()),
@ -213,7 +214,7 @@ public:
operands.append(newOperands.begin(), newOperands.end());
}
void addTypes(ArrayRef<Type *> newTypes) {
void addTypes(ArrayRef<Type> newTypes) {
types.append(newTypes.begin(), newTypes.end());
}

View File

@ -25,7 +25,6 @@
#include "mlir/IR/Types.h"
#include "mlir/IR/UseDefLists.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/PointerIntPair.h"
namespace mlir {
class Function;
@ -51,7 +50,7 @@ public:
SSAValueKind getKind() const { return typeAndKind.getInt(); }
Type *getType() const { return typeAndKind.getPointer(); }
Type getType() const { return typeAndKind.getPointer(); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
@ -93,9 +92,10 @@ public:
void dump() const;
protected:
SSAValue(SSAValueKind kind, Type *type) : typeAndKind(type, kind) {}
SSAValue(SSAValueKind kind, Type type) : typeAndKind(type, kind) {}
private:
const llvm::PointerIntPair<Type *, 3, SSAValueKind> typeAndKind;
const llvm::PointerIntPair<Type, 3, SSAValueKind> typeAndKind;
};
inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) {
@ -127,7 +127,7 @@ public:
inline use_range getUses() const;
protected:
SSAValueImpl(KindTy kind, Type *type) : SSAValue((SSAValueKind)kind, type) {}
SSAValueImpl(KindTy kind, Type type) : SSAValue((SSAValueKind)kind, type) {}
};
// Utility functions for iterating through SSAValue uses.

View File

@ -44,7 +44,7 @@ public:
/// Create a new OperationStmt with the specific fields.
static OperationStmt *create(Location *location, OperationName name,
ArrayRef<MLValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context);
@ -329,7 +329,7 @@ public:
//===--------------------------------------------------------------------===//
/// Return the context this operation is associated with.
MLIRContext *getContext() const { return getType()->getContext(); }
MLIRContext *getContext() const { return getType().getContext(); }
using Statement::dump;
using Statement::print;

View File

@ -20,6 +20,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
namespace mlir {
class AffineMap;
@ -28,6 +29,22 @@ class IntegerType;
class FloatType;
class OtherType;
namespace detail {
class TypeStorage;
class IntegerTypeStorage;
class FloatTypeStorage;
struct OtherTypeStorage;
struct FunctionTypeStorage;
struct VectorOrTensorTypeStorage;
struct VectorTypeStorage;
struct TensorTypeStorage;
struct RankedTensorTypeStorage;
struct UnrankedTensorTypeStorage;
struct MemRefTypeStorage;
} // namespace detail
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
/// MLIRContext. As such, they are passed around by raw non-const pointer.
///
@ -68,11 +85,34 @@ public:
MemRef,
};
using ImplType = detail::TypeStorage;
Type() : type(nullptr) {}
/* implicit */ Type(const ImplType *type)
: type(const_cast<ImplType *>(type)) {}
Type(const Type &other) : type(other.type) {}
Type &operator=(Type other) {
type = other.type;
return *this;
}
bool operator==(Type other) const { return type == other.type; }
bool operator!=(Type other) const { return !(*this == other); }
explicit operator bool() const { return type; }
bool operator!() const { return type == nullptr; }
template <typename U> bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
/// Return the classification for this type.
Kind getKind() const { return kind; }
Kind getKind() const;
/// Return the LLVMContext in which this type was uniqued.
MLIRContext *getContext() const { return context; }
MLIRContext *getContext() const;
// Convenience predicates. This is only for 'other' and floating point types,
// derived types should use isa/dyn_cast.
@ -97,56 +137,42 @@ public:
unsigned getBitWidth() const;
// Convenience factories.
static IntegerType *getInteger(unsigned width, MLIRContext *ctx);
static FloatType *getBF16(MLIRContext *ctx);
static FloatType *getF16(MLIRContext *ctx);
static FloatType *getF32(MLIRContext *ctx);
static FloatType *getF64(MLIRContext *ctx);
static OtherType *getIndex(MLIRContext *ctx);
static OtherType *getTFControl(MLIRContext *ctx);
static OtherType *getTFString(MLIRContext *ctx);
static OtherType *getTFResource(MLIRContext *ctx);
static OtherType *getTFVariant(MLIRContext *ctx);
static OtherType *getTFComplex64(MLIRContext *ctx);
static OtherType *getTFComplex128(MLIRContext *ctx);
static OtherType *getTFF32REF(MLIRContext *ctx);
static IntegerType getInteger(unsigned width, MLIRContext *ctx);
static FloatType getBF16(MLIRContext *ctx);
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
static OtherType getIndex(MLIRContext *ctx);
static OtherType getTFControl(MLIRContext *ctx);
static OtherType getTFString(MLIRContext *ctx);
static OtherType getTFResource(MLIRContext *ctx);
static OtherType getTFVariant(MLIRContext *ctx);
static OtherType getTFComplex64(MLIRContext *ctx);
static OtherType getTFComplex128(MLIRContext *ctx);
static OtherType getTFF32REF(MLIRContext *ctx);
/// Print the current type.
void print(raw_ostream &os) const;
void dump() const;
friend ::llvm::hash_code hash_value(Type arg);
unsigned getSubclassData() const;
void setSubclassData(unsigned val);
/// Methods for supporting PointerLikeTypeTraits.
const void *getAsOpaquePointer() const {
return static_cast<const void *>(type);
}
static Type getFromOpaquePointer(const void *pointer) {
return Type((ImplType *)(pointer));
}
protected:
explicit Type(Kind kind, MLIRContext *context)
: context(context), kind(kind), subclassData(0) {}
explicit Type(Kind kind, MLIRContext *context, unsigned subClassData)
: Type(kind, context) {
setSubclassData(subClassData);
}
~Type() {}
unsigned getSubclassData() const { return subclassData; }
void setSubclassData(unsigned val) {
subclassData = val;
// Ensure we don't have any accidental truncation.
assert(getSubclassData() == val && "Subclass data too large for field");
}
private:
Type(const Type&) = delete;
void operator=(const Type&) = delete;
/// This refers to the MLIRContext in which this type was uniqued.
MLIRContext *const context;
/// Classification of the subclass, used for type checking.
Kind kind : 8;
// Space for subclasses to store data.
unsigned subclassData : 24;
ImplType *type;
};
inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
inline raw_ostream &operator<<(raw_ostream &os, Type type) {
type.print(os);
return os;
}
@ -154,148 +180,138 @@ inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
class IntegerType : public Type {
public:
static IntegerType *get(unsigned width, MLIRContext *context);
using ImplType = detail::IntegerTypeStorage;
IntegerType() = default;
/* implicit */ IntegerType(Type::ImplType *ptr);
static IntegerType get(unsigned width, MLIRContext *context);
/// Return the bitwidth of this integer type.
unsigned getWidth() const {
return width;
}
unsigned getWidth() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == Kind::Integer;
}
static bool kindof(Kind kind) { return kind == Kind::Integer; }
/// Integer representation maximal bitwidth.
static constexpr unsigned kMaxWidth = 4096;
private:
unsigned width;
IntegerType(unsigned width, MLIRContext *context);
~IntegerType() = delete;
};
inline IntegerType *Type::getInteger(unsigned width, MLIRContext *ctx) {
inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
return IntegerType::get(width, ctx);
}
/// Return true if this is an integer type with the specified width.
inline bool Type::isInteger(unsigned width) const {
if (auto *intTy = dyn_cast<IntegerType>(this))
return intTy->getWidth() == width;
if (auto intTy = dyn_cast<IntegerType>())
return intTy.getWidth() == width;
return false;
}
class FloatType : public Type {
public:
using ImplType = detail::FloatTypeStorage;
FloatType() = default;
/* implicit */ FloatType(Type::ImplType *ptr);
static FloatType get(Kind kind, MLIRContext *context);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() >= Kind::FIRST_FLOATING_POINT_TYPE &&
type->getKind() <= Kind::LAST_FLOATING_POINT_TYPE;
static bool kindof(Kind kind) {
return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
kind <= Kind::LAST_FLOATING_POINT_TYPE;
}
static FloatType *get(Kind kind, MLIRContext *context);
private:
FloatType(Kind kind, MLIRContext *context);
~FloatType() = delete;
};
inline FloatType *Type::getBF16(MLIRContext *ctx) {
inline FloatType Type::getBF16(MLIRContext *ctx) {
return FloatType::get(Kind::BF16, ctx);
}
inline FloatType *Type::getF16(MLIRContext *ctx) {
inline FloatType Type::getF16(MLIRContext *ctx) {
return FloatType::get(Kind::F16, ctx);
}
inline FloatType *Type::getF32(MLIRContext *ctx) {
inline FloatType Type::getF32(MLIRContext *ctx) {
return FloatType::get(Kind::F32, ctx);
}
inline FloatType *Type::getF64(MLIRContext *ctx) {
inline FloatType Type::getF64(MLIRContext *ctx) {
return FloatType::get(Kind::F64, ctx);
}
/// This is a type for the random collection of special base types.
class OtherType : public Type {
public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() >= Kind::FIRST_OTHER_TYPE &&
type->getKind() <= Kind::LAST_OTHER_TYPE;
}
static OtherType *get(Kind kind, MLIRContext *context);
using ImplType = detail::OtherTypeStorage;
OtherType() = default;
/* implicit */ OtherType(Type::ImplType *ptr);
private:
OtherType(Kind kind, MLIRContext *context);
~OtherType() = delete;
static OtherType get(Kind kind, MLIRContext *context);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) {
return kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE;
}
};
inline OtherType *Type::getIndex(MLIRContext *ctx) {
inline OtherType Type::getIndex(MLIRContext *ctx) {
return OtherType::get(Kind::Index, ctx);
}
inline OtherType *Type::getTFControl(MLIRContext *ctx) {
inline OtherType Type::getTFControl(MLIRContext *ctx) {
return OtherType::get(Kind::TFControl, ctx);
}
inline OtherType *Type::getTFResource(MLIRContext *ctx) {
inline OtherType Type::getTFResource(MLIRContext *ctx) {
return OtherType::get(Kind::TFResource, ctx);
}
inline OtherType *Type::getTFString(MLIRContext *ctx) {
inline OtherType Type::getTFString(MLIRContext *ctx) {
return OtherType::get(Kind::TFString, ctx);
}
inline OtherType *Type::getTFVariant(MLIRContext *ctx) {
inline OtherType Type::getTFVariant(MLIRContext *ctx) {
return OtherType::get(Kind::TFVariant, ctx);
}
inline OtherType *Type::getTFComplex64(MLIRContext *ctx) {
inline OtherType Type::getTFComplex64(MLIRContext *ctx) {
return OtherType::get(Kind::TFComplex64, ctx);
}
inline OtherType *Type::getTFComplex128(MLIRContext *ctx) {
inline OtherType Type::getTFComplex128(MLIRContext *ctx) {
return OtherType::get(Kind::TFComplex128, ctx);
}
inline OtherType *Type::getTFF32REF(MLIRContext *ctx) {
inline OtherType Type::getTFF32REF(MLIRContext *ctx) {
return OtherType::get(Kind::TFF32REF, ctx);
}
/// Function types map from a list of inputs to a list of results.
class FunctionType : public Type {
public:
static FunctionType *get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
MLIRContext *context);
using ImplType = detail::FunctionTypeStorage;
FunctionType() = default;
/* implicit */ FunctionType(Type::ImplType *ptr);
static FunctionType get(ArrayRef<Type> inputs, ArrayRef<Type> results,
MLIRContext *context);
// Input types.
unsigned getNumInputs() const { return getSubclassData(); }
Type *getInput(unsigned i) const { return getInputs()[i]; }
Type getInput(unsigned i) const { return getInputs()[i]; }
ArrayRef<Type*> getInputs() const {
return ArrayRef<Type *>(inputsAndResults, getNumInputs());
}
ArrayRef<Type> getInputs() const;
// Result types.
unsigned getNumResults() const { return numResults; }
unsigned getNumResults() const;
Type *getResult(unsigned i) const { return getResults()[i]; }
Type getResult(unsigned i) const { return getResults()[i]; }
ArrayRef<Type*> getResults() const {
return ArrayRef<Type *>(inputsAndResults + getSubclassData(), numResults);
}
ArrayRef<Type> getResults() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == Kind::Function;
}
private:
unsigned numResults;
Type *const *inputsAndResults;
FunctionType(Type *const *inputsAndResults, unsigned numInputs,
unsigned numResults, MLIRContext *context);
~FunctionType() = delete;
static bool kindof(Kind kind) { return kind == Kind::Function; }
};
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
/// types, because many operations work on values of these aggregate types.
class VectorOrTensorType : public Type {
public:
Type *getElementType() const { return elementType; }
using ImplType = detail::VectorOrTensorTypeStorage;
VectorOrTensorType() = default;
/* implicit */ VectorOrTensorType(Type::ImplType *ptr);
Type getElementType() const;
/// If this is ranked tensor or vector type, return the number of elements. If
/// it is an unranked tensor or vector, abort.
@ -319,56 +335,40 @@ public:
int getDimSize(unsigned i) const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == Kind::Vector ||
type->getKind() == Kind::RankedTensor ||
type->getKind() == Kind::UnrankedTensor;
static bool kindof(Kind kind) {
return kind == Kind::Vector || kind == Kind::RankedTensor ||
kind == Kind::UnrankedTensor;
}
public:
Type *elementType;
VectorOrTensorType(Kind kind, MLIRContext *context, Type *elementType,
unsigned subClassData = 0);
};
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
/// known constant shape with one or more dimension.
class VectorType : public VectorOrTensorType {
public:
static VectorType *get(ArrayRef<int> shape, Type *elementType);
using ImplType = detail::VectorTypeStorage;
VectorType() = default;
/* implicit */ VectorType(Type::ImplType *ptr);
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
static VectorType get(ArrayRef<int> shape, Type elementType);
ArrayRef<int> getShape() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == Kind::Vector;
}
private:
const int *shapeElements;
Type *elementType;
VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context);
~VectorType() = delete;
static bool kindof(Kind kind) { return kind == Kind::Vector; }
};
/// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType.
class TensorType : public VectorOrTensorType {
public:
using ImplType = detail::TensorTypeStorage;
TensorType() = default;
/* implicit */ TensorType(Type::ImplType *ptr);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == Kind::RankedTensor ||
type->getKind() == Kind::UnrankedTensor;
static bool kindof(Kind kind) {
return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor;
}
protected:
TensorType(Kind kind, Type *elementType, MLIRContext *context);
~TensorType() {}
};
/// Ranked tensor types represent multi-dimensional arrays that have a shape
@ -376,40 +376,30 @@ protected:
/// integer or unknown (represented -1).
class RankedTensorType : public TensorType {
public:
static RankedTensorType *get(ArrayRef<int> shape,
Type *elementType);
using ImplType = detail::RankedTensorTypeStorage;
RankedTensorType() = default;
/* implicit */ RankedTensorType(Type::ImplType *ptr);
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
static RankedTensorType get(ArrayRef<int> shape, Type elementType);
static bool classof(const Type *type) {
return type->getKind() == Kind::RankedTensor;
}
ArrayRef<int> getShape() const;
private:
const int *shapeElements;
RankedTensorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context);
~RankedTensorType() = delete;
static bool kindof(Kind kind) { return kind == Kind::RankedTensor; }
};
/// Unranked tensor types represent multi-dimensional arrays that have an
/// unknown shape.
class UnrankedTensorType : public TensorType {
public:
static UnrankedTensorType *get(Type *elementType);
using ImplType = detail::UnrankedTensorTypeStorage;
UnrankedTensorType() = default;
/* implicit */ UnrankedTensorType(Type::ImplType *ptr);
static UnrankedTensorType get(Type elementType);
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
static bool classof(const Type *type) {
return type->getKind() == Kind::UnrankedTensor;
}
private:
UnrankedTensorType(Type *elementType, MLIRContext *context);
~UnrankedTensorType() = delete;
static bool kindof(Kind kind) { return kind == Kind::UnrankedTensor; }
};
/// MemRef types represent a region of memory that have a shape with a fixed
@ -418,62 +408,96 @@ private:
/// affine map composition, represented as an array AffineMap pointers.
class MemRefType : public Type {
public:
using ImplType = detail::MemRefTypeStorage;
MemRefType() = default;
/* implicit */ MemRefType(Type::ImplType *ptr);
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space.
static MemRefType *get(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace);
static MemRefType get(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace);
unsigned getRank() const { return getShape().size(); }
/// Returns an array of memref shape dimension sizes.
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
ArrayRef<int> getShape() const;
/// Return the size of the specified dimension, or -1 if unspecified.
int getDimSize(unsigned i) const { return getShape()[i]; }
/// Returns the elemental type for this memref shape.
Type *getElementType() const { return elementType; }
Type getElementType() const;
/// Returns an array of affine map pointers representing the memref affine
/// map composition.
ArrayRef<AffineMap> getAffineMaps() const;
/// Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpace() const { return memorySpace; }
unsigned getMemorySpace() const;
/// Returns the number of dimensions with dynamic size.
unsigned getNumDynamicDims() const;
static bool classof(const Type *type) {
return type->getKind() == Kind::MemRef;
}
private:
/// The type of each scalar element of the memref.
Type *elementType;
/// An array of integers which stores the shape dimension sizes.
const int *shapeElements;
/// The number of affine maps in the 'affineMapList' array.
const unsigned numAffineMaps;
/// List of affine maps in the memref's layout/index map composition.
AffineMap const *affineMapList;
/// Memory space in which data referenced by memref resides.
const unsigned memorySpace;
MemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
MLIRContext *context);
~MemRefType() = delete;
static bool kindof(Kind kind) { return kind == Kind::MemRef; }
};
/// Return true if the specified element type is ok in a tensor.
static bool isValidTensorElementType(Type *type) {
return isa<FloatType>(type) || isa<VectorType>(type) ||
isa<IntegerType>(type) || isa<OtherType>(type);
// Make Type hashable.
inline ::llvm::hash_code hash_value(Type arg) {
return ::llvm::hash_value(arg.type);
}
template <typename U> bool Type::isa() const {
assert(type && "isa<> used on a null type.");
return U::kindof(getKind());
}
template <typename U> U Type::dyn_cast() const {
return isa<U>() ? U(type) : U(nullptr);
}
template <typename U> U Type::dyn_cast_or_null() const {
return (type && isa<U>()) ? U(type) : U(nullptr);
}
template <typename U> U Type::cast() const {
assert(isa<U>());
return U(type);
}
/// Return true if the specified element type is ok in a tensor.
static bool isValidTensorElementType(Type type) {
return type.isa<FloatType>() || type.isa<VectorType>() ||
type.isa<IntegerType>() || type.isa<OtherType>();
}
} // end namespace mlir
namespace llvm {
// Type hash just like pointers.
template <> struct DenseMapInfo<mlir::Type> {
static mlir::Type getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
}
static mlir::Type getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
};
/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
template <> struct PointerLikeTypeTraits<mlir::Type> {
public:
static inline void *getAsVoidPointer(mlir::Type I) {
return const_cast<void *>(I.getAsOpaquePointer());
}
static inline mlir::Type getFromVoidPointer(void *P) {
return mlir::Type::getFromOpaquePointer(P);
}
enum { NumLowBitsAvailable = 3 };
};
} // namespace llvm
#endif // MLIR_IR_TYPES_H

View File

@ -104,15 +104,15 @@ class AllocOp
: public Op<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
public:
/// The result of an alloc is always a MemRefType.
MemRefType *getType() const {
return cast<MemRefType>(getResult()->getType());
MemRefType getType() const {
return getResult()->getType().cast<MemRefType>();
}
static StringRef getOperationName() { return "alloc"; }
// Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result,
MemRefType *memrefType, ArrayRef<SSAValue *> operands = {});
MemRefType memrefType, ArrayRef<SSAValue *> operands = {});
bool verify() const;
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
@ -276,7 +276,7 @@ public:
const SSAValue *getSrcMemRef() const { return getOperand(0); }
// Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() const {
return cast<MemRefType>(getSrcMemRef()->getType())->getRank();
return getSrcMemRef()->getType().cast<MemRefType>().getRank();
}
// Returns the source memerf indices for this DMA operation.
llvm::iterator_range<Operation::const_operand_iterator>
@ -291,13 +291,13 @@ public:
}
// Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() const {
return cast<MemRefType>(getDstMemRef()->getType())->getRank();
return getDstMemRef()->getType().cast<MemRefType>().getRank();
}
unsigned getSrcMemorySpace() const {
return cast<MemRefType>(getSrcMemRef()->getType())->getMemorySpace();
return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
}
unsigned getDstMemorySpace() const {
return cast<MemRefType>(getDstMemRef()->getType())->getMemorySpace();
return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
}
// Returns the destination memref indices for this DMA operation.
@ -387,7 +387,7 @@ public:
// Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() const {
return cast<MemRefType>(getTagMemRef()->getType())->getRank();
return getTagMemRef()->getType().cast<MemRefType>().getRank();
}
// Returns the number of elements transferred in the associated DMA operation.
@ -460,8 +460,8 @@ public:
SSAValue *getMemRef() { return getOperand(0); }
const SSAValue *getMemRef() const { return getOperand(0); }
void setMemRef(SSAValue *value) { setOperand(0, value); }
MemRefType *getMemRefType() const {
return cast<MemRefType>(getMemRef()->getType());
MemRefType getMemRefType() const {
return getMemRef()->getType().cast<MemRefType>();
}
llvm::iterator_range<Operation::operand_iterator> getIndices() {
@ -508,8 +508,8 @@ public:
static StringRef getOperationName() { return "memref_cast"; }
/// The result of a memref_cast is always a memref.
MemRefType *getType() const {
return cast<MemRefType>(getResult()->getType());
MemRefType getType() const {
return getResult()->getType().cast<MemRefType>();
}
bool verify() const;
@ -583,8 +583,8 @@ public:
SSAValue *getMemRef() { return getOperand(1); }
const SSAValue *getMemRef() const { return getOperand(1); }
void setMemRef(SSAValue *value) { setOperand(1, value); }
MemRefType *getMemRefType() const {
return cast<MemRefType>(getMemRef()->getType());
MemRefType getMemRefType() const {
return getMemRef()->getType().cast<MemRefType>();
}
llvm::iterator_range<Operation::operand_iterator> getIndices() {
@ -671,8 +671,8 @@ public:
static StringRef getOperationName() { return "tensor_cast"; }
/// The result of a tensor_cast is always a tensor.
TensorType *getType() const {
return cast<TensorType>(getResult()->getType());
TensorType getType() const {
return getResult()->getType().cast<TensorType>();
}
bool verify() const;

View File

@ -118,15 +118,15 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
return tripCountExpr.getLargestKnownDivisor();
}
bool mlir::isAccessInvariant(const MLValue &input, MemRefType *memRefType,
bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
ArrayRef<MLValue *> indices, unsigned dim) {
assert(indices.size() == memRefType->getRank());
assert(indices.size() == memRefType.getRank());
assert(dim < indices.size());
auto layoutMap = memRefType->getAffineMaps();
assert(memRefType->getAffineMaps().size() <= 1);
auto layoutMap = memRefType.getAffineMaps();
assert(memRefType.getAffineMaps().size() <= 1);
// TODO(ntv): remove dependency on Builder once we support non-identity
// layout map.
Builder b(memRefType->getContext());
Builder b(memRefType.getContext());
assert(layoutMap.empty() ||
layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
(void)layoutMap;
@ -170,7 +170,7 @@ static bool isContiguousAccess(const MLValue &input,
using namespace functional;
auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
memoryOp->getIndices());
auto *memRefType = memoryOp->getMemRefType();
auto memRefType = memoryOp->getMemRefType();
for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
if (fastestVaryingDim == (numIndices - 1) - d) {
continue;
@ -184,8 +184,8 @@ static bool isContiguousAccess(const MLValue &input,
template <typename LoadOrStoreOpPointer>
static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
auto *memRefType = memoryOp->getMemRefType();
return isa<VectorType>(memRefType->getElementType());
auto memRefType = memoryOp->getMemRefType();
return memRefType.getElementType().template isa<VectorType>();
}
bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {

View File

@ -195,7 +195,7 @@ bool CFGFuncVerifier::verify() {
// Verify that the argument list of the function and the arg list of the first
// block line up.
auto fnInputTypes = fn.getType()->getInputs();
auto fnInputTypes = fn.getType().getInputs();
if (fnInputTypes.size() != firstBB->getNumArguments())
return failure("first block of cfgfunc must have " +
Twine(fnInputTypes.size()) +
@ -306,7 +306,7 @@ bool CFGFuncVerifier::verifyBBArguments(ArrayRef<InstOperand> operands,
bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
// Verify that the return operands match the results of the function.
auto results = fn.getType()->getResults();
auto results = fn.getType().getResults();
if (inst.getNumOperands() != results.size())
return failure("return has " + Twine(inst.getNumOperands()) +
" operands, but enclosing function returns " +

View File

@ -122,7 +122,7 @@ private:
void visitForStmt(const ForStmt *forStmt);
void visitIfStmt(const IfStmt *ifStmt);
void visitOperationStmt(const OperationStmt *opStmt);
void visitType(const Type *type);
void visitType(Type type);
void visitAttribute(Attribute attr);
void visitOperation(const Operation *op);
@ -135,16 +135,16 @@ private:
} // end anonymous namespace
// TODO Support visiting other types/instructions when implemented.
void ModuleState::visitType(const Type *type) {
if (auto *funcType = dyn_cast<FunctionType>(type)) {
void ModuleState::visitType(Type type) {
if (auto funcType = type.dyn_cast<FunctionType>()) {
// Visit input and result types for functions.
for (auto *input : funcType->getInputs())
for (auto input : funcType.getInputs())
visitType(input);
for (auto *result : funcType->getResults())
for (auto result : funcType.getResults())
visitType(result);
} else if (auto *memref = dyn_cast<MemRefType>(type)) {
} else if (auto memref = type.dyn_cast<MemRefType>()) {
// Visit affine maps in memref type.
for (auto map : memref->getAffineMaps()) {
for (auto map : memref.getAffineMaps()) {
recordAffineMapReference(map);
}
}
@ -271,7 +271,7 @@ public:
void print(const Module *module);
void printFunctionReference(const Function *func);
void printAttribute(Attribute attr);
void printType(const Type *type);
void printType(Type type);
void print(const Function *fn);
void print(const ExtFunction *fn);
void print(const CFGFunction *fn);
@ -290,7 +290,7 @@ protected:
void printFunctionAttributes(const Function *fn);
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {});
void printFunctionResultType(const FunctionType *type);
void printFunctionResultType(FunctionType type);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(AffineMap affineMap);
void printIntegerSetId(int integerSetId) const;
@ -489,9 +489,9 @@ void ModulePrinter::printAttribute(Attribute attr) {
}
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
auto *type = attr.getType();
auto shape = type->getShape();
auto rank = type->getRank();
auto type = attr.getType();
auto shape = type.getShape();
auto rank = type.getRank();
SmallVector<Attribute, 16> elements;
attr.getValues(elements);
@ -541,8 +541,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
os << ']';
}
void ModulePrinter::printType(const Type *type) {
switch (type->getKind()) {
void ModulePrinter::printType(Type type) {
switch (type.getKind()) {
case Type::Kind::Index:
os << "index";
return;
@ -581,71 +581,71 @@ void ModulePrinter::printType(const Type *type) {
return;
case Type::Kind::Integer: {
auto *integer = cast<IntegerType>(type);
os << 'i' << integer->getWidth();
auto integer = type.cast<IntegerType>();
os << 'i' << integer.getWidth();
return;
}
case Type::Kind::Function: {
auto *func = cast<FunctionType>(type);
auto func = type.cast<FunctionType>();
os << '(';
interleaveComma(func->getInputs(), [&](Type *type) { printType(type); });
interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
os << ") -> ";
auto results = func->getResults();
auto results = func.getResults();
if (results.size() == 1)
os << *results[0];
os << results[0];
else {
os << '(';
interleaveComma(results, [&](Type *type) { printType(type); });
interleaveComma(results, [&](Type type) { printType(type); });
os << ')';
}
return;
}
case Type::Kind::Vector: {
auto *v = cast<VectorType>(type);
auto v = type.cast<VectorType>();
os << "vector<";
for (auto dim : v->getShape())
for (auto dim : v.getShape())
os << dim << 'x';
os << *v->getElementType() << '>';
os << v.getElementType() << '>';
return;
}
case Type::Kind::RankedTensor: {
auto *v = cast<RankedTensorType>(type);
auto v = type.cast<RankedTensorType>();
os << "tensor<";
for (auto dim : v->getShape()) {
for (auto dim : v.getShape()) {
if (dim < 0)
os << '?';
else
os << dim;
os << 'x';
}
os << *v->getElementType() << '>';
os << v.getElementType() << '>';
return;
}
case Type::Kind::UnrankedTensor: {
auto *v = cast<UnrankedTensorType>(type);
auto v = type.cast<UnrankedTensorType>();
os << "tensor<*x";
printType(v->getElementType());
printType(v.getElementType());
os << '>';
return;
}
case Type::Kind::MemRef: {
auto *v = cast<MemRefType>(type);
auto v = type.cast<MemRefType>();
os << "memref<";
for (auto dim : v->getShape()) {
for (auto dim : v.getShape()) {
if (dim < 0)
os << '?';
else
os << dim;
os << 'x';
}
printType(v->getElementType());
for (auto map : v->getAffineMaps()) {
printType(v.getElementType());
for (auto map : v.getAffineMaps()) {
os << ", ";
printAffineMapReference(map);
}
// Only print the memory space if it is the non-default one.
if (v->getMemorySpace())
os << ", " << v->getMemorySpace();
if (v.getMemorySpace())
os << ", " << v.getMemorySpace();
os << '>';
return;
}
@ -842,18 +842,18 @@ void ModulePrinter::printIntegerSet(IntegerSet set) {
// Function printing
//===----------------------------------------------------------------------===//
void ModulePrinter::printFunctionResultType(const FunctionType *type) {
switch (type->getResults().size()) {
void ModulePrinter::printFunctionResultType(FunctionType type) {
switch (type.getResults().size()) {
case 0:
break;
case 1:
os << " -> ";
printType(type->getResults()[0]);
printType(type.getResults()[0]);
break;
default:
os << " -> (";
interleaveComma(type->getResults(),
[&](Type *eltType) { printType(eltType); });
interleaveComma(type.getResults(),
[&](Type eltType) { printType(eltType); });
os << ')';
break;
}
@ -871,8 +871,7 @@ void ModulePrinter::printFunctionSignature(const Function *fn) {
auto type = fn->getType();
os << "@" << fn->getName() << '(';
interleaveComma(type->getInputs(),
[&](Type *eltType) { printType(eltType); });
interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
os << ')';
printFunctionResultType(type);
@ -937,7 +936,7 @@ public:
// Implement OpAsmPrinter.
raw_ostream &getStream() const { return os; }
void printType(const Type *type) { ModulePrinter::printType(type); }
void printType(Type type) { ModulePrinter::printType(type); }
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
void printAffineMap(AffineMap map) {
return ModulePrinter::printAffineMapReference(map);
@ -974,10 +973,10 @@ protected:
if (auto *op = value->getDefiningOperation()) {
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
// i1 constants get special names.
if (intOp->getType()->isInteger(1)) {
if (intOp->getType().isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false");
} else {
specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
}
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue();
@ -1579,7 +1578,7 @@ void Attribute::dump() const { print(llvm::errs()); }
void Type::print(raw_ostream &os) const {
ModuleState state(getContext());
ModulePrinter(os, state).printType(this);
ModulePrinter(os, state).printType(*this);
}
void Type::dump() const { print(llvm::errs()); }

View File

@ -26,6 +26,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
@ -86,7 +87,7 @@ struct IntegerSetAttributeStorage : public AttributeStorage {
/// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage {
Type *value;
Type value;
};
/// An attribute representing a reference to a function.
@ -96,7 +97,7 @@ struct FunctionAttributeStorage : public AttributeStorage {
/// A base attribute representing a reference to a vector or tensor constant.
struct ElementsAttributeStorage : public AttributeStorage {
VectorOrTensorType *type;
VectorOrTensorType type;
};
/// An attribute representing a reference to a vector or tensor constant,

View File

@ -75,9 +75,7 @@ IntegerSet IntegerSetAttr::getValue() const {
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
Type *TypeAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
@ -85,11 +83,11 @@ Function *FunctionAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
VectorOrTensorType *ElementsAttr::getType() const {
VectorOrTensorType ElementsAttr::getType() const {
return static_cast<ImplType *>(attr)->type;
}
@ -166,8 +164,8 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
auto elementNum = getType().getNumElements();
auto context = getType().getContext();
values.reserve(elementNum);
if (bitsWidth == 64) {
ArrayRef<int64_t> vs(
@ -192,8 +190,8 @@ DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
: DenseElementsAttr(ptr) {}
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
auto elementNum = getType().getNumElements();
auto context = getType().getContext();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
getRawData().size() / 8});
values.reserve(elementNum);

View File

@ -33,18 +33,18 @@ BasicBlock::~BasicBlock() {
// Argument list management.
//===----------------------------------------------------------------------===//
BBArgument *BasicBlock::addArgument(Type *type) {
BBArgument *BasicBlock::addArgument(Type type) {
auto *arg = new BBArgument(type, this);
arguments.push_back(arg);
return arg;
}
/// Add one argument to the argument list for each type specified in the list.
auto BasicBlock::addArguments(ArrayRef<Type *> types)
auto BasicBlock::addArguments(ArrayRef<Type> types)
-> llvm::iterator_range<args_iterator> {
arguments.reserve(arguments.size() + types.size());
auto initialSize = arguments.size();
for (auto *type : types) {
for (auto type : types) {
addArgument(type);
}
return {arguments.data() + initialSize, arguments.data() + arguments.size()};

View File

@ -52,59 +52,58 @@ FileLineColLoc *Builder::getFileLineColLoc(UniquedFilename filename,
// Types.
//===----------------------------------------------------------------------===//
FloatType *Builder::getBF16Type() { return Type::getBF16(context); }
FloatType Builder::getBF16Type() { return Type::getBF16(context); }
FloatType *Builder::getF16Type() { return Type::getF16(context); }
FloatType Builder::getF16Type() { return Type::getF16(context); }
FloatType *Builder::getF32Type() { return Type::getF32(context); }
FloatType Builder::getF32Type() { return Type::getF32(context); }
FloatType *Builder::getF64Type() { return Type::getF64(context); }
FloatType Builder::getF64Type() { return Type::getF64(context); }
OtherType *Builder::getIndexType() { return Type::getIndex(context); }
OtherType Builder::getIndexType() { return Type::getIndex(context); }
OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
OtherType Builder::getTFResourceType() { return Type::getTFResource(context); }
OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); }
OtherType *Builder::getTFComplex64Type() {
OtherType Builder::getTFComplex64Type() {
return Type::getTFComplex64(context);
}
OtherType *Builder::getTFComplex128Type() {
OtherType Builder::getTFComplex128Type() {
return Type::getTFComplex128(context);
}
OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
OtherType Builder::getTFStringType() { return Type::getTFString(context); }
IntegerType *Builder::getIntegerType(unsigned width) {
IntegerType Builder::getIntegerType(unsigned width) {
return Type::getInteger(width, context);
}
FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
ArrayRef<Type *> results) {
FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
ArrayRef<Type> results) {
return FunctionType::get(inputs, results, context);
}
MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace) {
MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace) {
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
}
VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) {
return VectorType::get(shape, elementType);
}
RankedTensorType *Builder::getTensorType(ArrayRef<int> shape,
Type *elementType) {
RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) {
return RankedTensorType::get(shape, elementType);
}
UnrankedTensorType *Builder::getTensorType(Type *elementType) {
UnrankedTensorType Builder::getTensorType(Type elementType) {
return UnrankedTensorType::get(elementType);
}
@ -144,7 +143,7 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
return IntegerSetAttr::get(set);
}
TypeAttr Builder::getTypeAttr(Type *type) {
TypeAttr Builder::getTypeAttr(Type type) {
return TypeAttr::get(type, context);
}
@ -152,23 +151,23 @@ FunctionAttr Builder::getFunctionAttr(const Function *value) {
return FunctionAttr::get(value, context);
}
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,
Attribute elt) {
return SplatElementsAttr::get(type, elt);
}
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
ArrayRef<char> data) {
return DenseElementsAttr::get(type, data);
}
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {
return SparseElementsAttr::get(type, indices, values);
}
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type,
StringRef bytes) {
return OpaqueElementsAttr::get(type, bytes);
}
@ -296,7 +295,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
OperationStmt *MLFuncBuilder::createOperation(Location *location,
OperationName name,
ArrayRef<MLValue *> operands,
ArrayRef<Type *> types,
ArrayRef<Type> types,
ArrayRef<NamedAttribute> attrs) {
auto *op = OperationStmt::create(location, name, operands, types, attrs,
getContext());

View File

@ -63,7 +63,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
numDims = opInfos.size();
// Parse the optional symbol operands.
auto *affineIntTy = parser->getBuilder().getIndexType();
auto affineIntTy = parser->getBuilder().getIndexType();
if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
@ -84,7 +84,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result,
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getIndexType();
auto affineIntTy = builder.getIndexType();
AffineMapAttr mapAttr;
unsigned numDims;
@ -171,7 +171,7 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
/// Builds a constant op with the specified attribute value and result type.
void ConstantOp::build(Builder *builder, OperationState *result,
Attribute value, Type *type) {
Attribute value, Type type) {
result->addAttribute("value", value);
result->types.push_back(type);
}
@ -181,12 +181,12 @@ void ConstantOp::print(OpAsmPrinter *p) const {
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
if (!getValue().isa<FunctionAttr>())
*p << " : " << *getType();
*p << " : " << getType();
}
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute valueAttr;
Type *type;
Type type;
if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes))
@ -208,33 +208,33 @@ bool ConstantOp::verify() const {
if (!value)
return emitOpError("requires a 'value' attribute");
auto *type = this->getType();
if (isa<IntegerType>(type) || type->isIndex()) {
auto type = this->getType();
if (type.isa<IntegerType>() || type.isIndex()) {
if (!value.isa<IntegerAttr>())
return emitOpError(
"requires 'value' to be an integer for an integer result type");
return false;
}
if (isa<FloatType>(type)) {
if (type.isa<FloatType>()) {
if (!value.isa<FloatAttr>())
return emitOpError("requires 'value' to be a floating point constant");
return false;
}
if (isa<VectorOrTensorType>(type)) {
if (type.isa<VectorOrTensorType>()) {
if (!value.isa<ElementsAttr>())
return emitOpError("requires 'value' to be a vector/tensor constant");
return false;
}
if (type->isTFString()) {
if (type.isTFString()) {
if (!value.isa<StringAttr>())
return emitOpError("requires 'value' to be a string constant");
return false;
}
if (isa<FunctionType>(type)) {
if (type.isa<FunctionType>()) {
if (!value.isa<FunctionAttr>())
return emitOpError("requires 'value' to be a function reference");
return false;
@ -251,19 +251,19 @@ Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
}
void ConstantFloatOp::build(Builder *builder, OperationState *result,
const APFloat &value, FloatType *type) {
const APFloat &value, FloatType type) {
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
}
bool ConstantFloatOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&
isa<FloatType>(op->getResult(0)->getType());
op->getResult(0)->getType().isa<FloatType>();
}
/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&
isa<IntegerType>(op->getResult(0)->getType());
op->getResult(0)->getType().isa<IntegerType>();
}
void ConstantIntOp::build(Builder *builder, OperationState *result,
@ -275,14 +275,14 @@ void ConstantIntOp::build(Builder *builder, OperationState *result,
/// Build a constant int op producing an integer with the specified type,
/// which must be an integer type.
void ConstantIntOp::build(Builder *builder, OperationState *result,
int64_t value, Type *type) {
assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type");
int64_t value, Type type) {
assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
}
/// ConstantIndexOp only matches values whose result type is Index.
bool ConstantIndexOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
}
void ConstantIndexOp::build(Builder *builder, OperationState *result,
@ -302,7 +302,7 @@ void ReturnOp::build(Builder *builder, OperationState *result,
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type *, 2> types;
SmallVector<Type, 2> types;
llvm::SMLoc loc;
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
@ -330,7 +330,7 @@ bool ReturnOp::verify() const {
// The operand number and types must match the function signature.
MLFunction *function = cast<MLFunction>(block);
const auto &results = function->getType()->getResults();
const auto &results = function->getType().getResults();
if (stmt->getNumOperands() != results.size())
return emitOpError("has " + Twine(stmt->getNumOperands()) +
" operands, but enclosing function returns " +

View File

@ -28,8 +28,8 @@
using namespace mlir;
Function::Function(Kind kind, Location *location, StringRef name,
FunctionType *type, ArrayRef<NamedAttribute> attrs)
: nameAndKind(Identifier::get(name, type->getContext()), kind),
FunctionType type, ArrayRef<NamedAttribute> attrs)
: nameAndKind(Identifier::get(name, type.getContext()), kind),
location(location), type(type) {
this->attrs = AttributeListStorage::get(attrs, getContext());
}
@ -46,7 +46,7 @@ ArrayRef<NamedAttribute> Function::getAttrs() const {
return {};
}
MLIRContext *Function::getContext() const { return getType()->getContext(); }
MLIRContext *Function::getContext() const { return getType().getContext(); }
/// Delete this object.
void Function::destroy() {
@ -159,7 +159,7 @@ void Function::emitError(const Twine &message) const {
// ExtFunction implementation.
//===----------------------------------------------------------------------===//
ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::ExtFunc, location, name, type, attrs) {}
@ -167,7 +167,7 @@ ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
// CFGFunction implementation.
//===----------------------------------------------------------------------===//
CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type,
CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::CFGFunc, location, name, type, attrs) {}
@ -188,9 +188,9 @@ CFGFunction::~CFGFunction() {
/// Create a new MLFunction with the specific fields.
MLFunction *MLFunction::create(Location *location, StringRef name,
FunctionType *type,
FunctionType type,
ArrayRef<NamedAttribute> attrs) {
const auto &argTypes = type->getInputs();
const auto &argTypes = type.getInputs();
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
void *rawMem = malloc(byteSize);
@ -204,7 +204,7 @@ MLFunction *MLFunction::create(Location *location, StringRef name,
return function;
}
MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type,
MLFunction::MLFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::MLFunc, location, name, type, attrs),
StmtBlock(StmtBlockKind::MLFunc) {}

View File

@ -143,7 +143,7 @@ void Instruction::emitError(const Twine &message) const {
/// Create a new OperationInst with the specified fields.
OperationInst *OperationInst::create(Location *location, OperationName name,
ArrayRef<CFGValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context) {
auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(),
@ -167,7 +167,7 @@ OperationInst *OperationInst::create(Location *location, OperationName name,
OperationInst *OperationInst::clone() const {
SmallVector<CFGValue *, 8> operands;
SmallVector<Type *, 8> resultTypes;
SmallVector<Type, 8> resultTypes;
// Put together the operands and results.
for (auto *operand : getOperands())

View File

@ -21,6 +21,7 @@
#include "AttributeDetail.h"
#include "AttributeListStorage.h"
#include "IntegerSetDetail.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
@ -44,11 +45,11 @@ using namespace mlir::detail;
using namespace llvm;
namespace {
struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> {
// Functions are uniqued based on their inputs and results.
using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>;
using DenseMapInfo<FunctionType *>::getHashValue;
using DenseMapInfo<FunctionType *>::isEqual;
using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>;
using DenseMapInfo<FunctionTypeStorage *>::getHashValue;
using DenseMapInfo<FunctionTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
@ -56,7 +57,7 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
hash_combine_range(key.second.begin(), key.second.end()));
}
static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
@ -109,65 +110,64 @@ struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
}
};
struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> {
// Vectors are uniqued based on their element type and shape.
using KeyTy = std::pair<Type *, ArrayRef<int>>;
using DenseMapInfo<VectorType *>::getHashValue;
using DenseMapInfo<VectorType *>::isEqual;
using KeyTy = std::pair<Type, ArrayRef<int>>;
using DenseMapInfo<VectorTypeStorage *>::getHashValue;
using DenseMapInfo<VectorTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
DenseMapInfo<Type *>::getHashValue(key.first),
DenseMapInfo<Type>::getHashValue(key.first),
hash_combine_range(key.second.begin(), key.second.end()));
}
static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
return lhs == KeyTy(rhs->elementType, rhs->getShape());
}
};
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> {
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> {
// Ranked tensors are uniqued based on their element type and shape.
using KeyTy = std::pair<Type *, ArrayRef<int>>;
using DenseMapInfo<RankedTensorType *>::getHashValue;
using DenseMapInfo<RankedTensorType *>::isEqual;
using KeyTy = std::pair<Type, ArrayRef<int>>;
using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue;
using DenseMapInfo<RankedTensorTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
DenseMapInfo<Type *>::getHashValue(key.first),
DenseMapInfo<Type>::getHashValue(key.first),
hash_combine_range(key.second.begin(), key.second.end()));
}
static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
return lhs == KeyTy(rhs->elementType, rhs->getShape());
}
};
struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> {
// MemRefs are uniqued based on their element type, shape, affine map
// composition, and memory space.
using KeyTy =
std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
using DenseMapInfo<MemRefType *>::getHashValue;
using DenseMapInfo<MemRefType *>::isEqual;
using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
using DenseMapInfo<MemRefTypeStorage *>::getHashValue;
using DenseMapInfo<MemRefTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
DenseMapInfo<Type *>::getHashValue(std::get<0>(key)),
DenseMapInfo<Type>::getHashValue(std::get<0>(key)),
hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()),
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
std::get<3>(key));
}
static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) {
static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(),
rhs->getAffineMaps(), rhs->getMemorySpace());
return lhs == std::make_tuple(rhs->elementType, rhs->getShape(),
rhs->getAffineMaps(), rhs->memorySpace);
}
};
@ -221,7 +221,7 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
};
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
@ -239,7 +239,7 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
};
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
using KeyTy = std::pair<VectorOrTensorType, StringRef>;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
@ -295,13 +295,14 @@ public:
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
// Uniquing table for 'other' types.
OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr};
OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {
nullptr};
// Uniquing table for 'float' types.
FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = {
nullptr};
FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] =
{nullptr};
// Affine map uniquing.
using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
@ -324,26 +325,26 @@ public:
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
/// Integer type uniquing.
DenseMap<unsigned, IntegerType *> integers;
DenseMap<unsigned, IntegerTypeStorage *> integers;
/// Function type uniquing.
using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>;
using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>;
FunctionTypeSet functions;
/// Vector type uniquing.
using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>;
using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>;
VectorTypeSet vectors;
/// Ranked tensor type uniquing.
using RankedTensorTypeSet =
DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>;
DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>;
RankedTensorTypeSet rankedTensors;
/// Unranked tensor type uniquing.
DenseMap<Type *, UnrankedTensorType *> unrankedTensors;
DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors;
/// MemRef type uniquing.
using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>;
using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>;
MemRefTypeSet memrefs;
// Attribute uniquing.
@ -355,13 +356,12 @@ public:
ArrayAttrSet arrayAttrs;
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
DenseMap<Type, TypeAttributeStorage *> typeAttrs;
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
DenseMap<std::pair<VectorOrTensorType *, Attribute>,
SplatElementsAttributeStorage *>
DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *>
splatElementsAttrs;
using DenseElementsAttrSet =
DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
@ -369,7 +369,7 @@ public:
using OpaqueElementsAttrSet =
DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
OpaqueElementsAttrSet opaqueElementsAttrs;
DenseMap<std::tuple<Type *, Attribute, Attribute>,
DenseMap<std::tuple<Type, Attribute, Attribute>,
SparseElementsAttributeStorage *>
sparseElementsAttrs;
@ -556,19 +556,20 @@ FileLineColLoc *FileLineColLoc::get(UniquedFilename filename, unsigned line,
// Type uniquing
//===----------------------------------------------------------------------===//
IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
auto &impl = context->getImpl();
auto *&result = impl.integers[width];
if (!result) {
result = impl.allocator.Allocate<IntegerType>();
new (result) IntegerType(width, context);
result = impl.allocator.Allocate<IntegerTypeStorage>();
new (result) IntegerTypeStorage{{Kind::Integer, context}, width};
}
return result;
}
FloatType *FloatType::get(Kind kind, MLIRContext *context) {
FloatType FloatType::get(Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
auto &impl = context->getImpl();
@ -580,16 +581,16 @@ FloatType *FloatType::get(Kind kind, MLIRContext *context) {
return entry;
// On the first use, we allocate them into the bump pointer.
auto *ptr = impl.allocator.Allocate<FloatType>();
auto *ptr = impl.allocator.Allocate<FloatTypeStorage>();
// Initialize the memory using placement new.
new (ptr) FloatType(kind, context);
new (ptr) FloatTypeStorage{{kind, context}};
// Cache and return it.
return entry = ptr;
}
OtherType *OtherType::get(Kind kind, MLIRContext *context) {
OtherType OtherType::get(Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE &&
"Not an 'other' type kind");
auto &impl = context->getImpl();
@ -600,18 +601,17 @@ OtherType *OtherType::get(Kind kind, MLIRContext *context) {
return entry;
// On the first use, we allocate them into the bump pointer.
auto *ptr = impl.allocator.Allocate<OtherType>();
auto *ptr = impl.allocator.Allocate<OtherTypeStorage>();
// Initialize the memory using placement new.
new (ptr) OtherType(kind, context);
new (ptr) OtherTypeStorage{{kind, context}};
// Cache and return it.
return entry = ptr;
}
FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
ArrayRef<Type *> results,
MLIRContext *context) {
FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
MLIRContext *context) {
auto &impl = context->getImpl();
// Look to see if we already have this function type.
@ -623,32 +623,34 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
return *existing.first;
// On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<FunctionType>();
auto *result = impl.allocator.Allocate<FunctionTypeStorage>();
// Copy the inputs and results into the bump pointer.
SmallVector<Type *, 16> types;
SmallVector<Type, 16> types;
types.reserve(inputs.size() + results.size());
types.append(inputs.begin(), inputs.end());
types.append(results.begin(), results.end());
auto typesList = impl.copyInto(ArrayRef<Type *>(types));
auto typesList = impl.copyInto(ArrayRef<Type>(types));
// Initialize the memory using placement new.
new (result)
FunctionType(typesList.data(), inputs.size(), results.size(), context);
new (result) FunctionTypeStorage{
{Kind::Function, context, static_cast<unsigned int>(inputs.size())},
static_cast<unsigned int>(results.size()),
typesList.data()};
// Cache and return it.
return *existing.first = result;
}
VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
VectorType VectorType::get(ArrayRef<int> shape, Type elementType) {
assert(!shape.empty() && "vector types must have at least one dimension");
assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) &&
"vectors elements must be primitives");
assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
return i < 0;
}) && "vector types must have static shape");
auto *context = elementType->getContext();
auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Look to see if we already have this vector type.
@ -660,21 +662,23 @@ VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
return *existing.first;
// On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<VectorType>();
auto *result = impl.allocator.Allocate<VectorTypeStorage>();
// Copy the shape into the bump pointer.
shape = impl.copyInto(shape);
// Initialize the memory using placement new.
new (result) VectorType(shape, elementType, context);
new (result) VectorTypeStorage{
{{Kind::Vector, context, static_cast<unsigned int>(shape.size())},
elementType},
shape.data()};
// Cache and return it.
return *existing.first = result;
}
RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
Type *elementType) {
auto *context = elementType->getContext();
RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) {
auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Look to see if we already have this ranked tensor type.
@ -686,20 +690,23 @@ RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
return *existing.first;
// On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<RankedTensorType>();
auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>();
// Copy the shape into the bump pointer.
shape = impl.copyInto(shape);
// Initialize the memory using placement new.
new (result) RankedTensorType(shape, elementType, context);
new (result) RankedTensorTypeStorage{
{{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())},
elementType}},
shape.data()};
// Cache and return it.
return *existing.first = result;
}
UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
auto *context = elementType->getContext();
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Look to see if we already have this unranked tensor type.
@ -710,17 +717,18 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
return result;
// On the first use, we allocate them into the bump pointer.
result = impl.allocator.Allocate<UnrankedTensorType>();
result = impl.allocator.Allocate<UnrankedTensorTypeStorage>();
// Initialize the memory using placement new.
new (result) UnrankedTensorType(elementType, context);
new (result) UnrankedTensorTypeStorage{
{{{Kind::UnrankedTensor, context}, elementType}}};
return result;
}
MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace) {
auto *context = elementType->getContext();
MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace) {
auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Drop the unbounded identity maps from the composition.
@ -744,7 +752,7 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
return *existing.first;
// On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<MemRefType>();
auto *result = impl.allocator.Allocate<MemRefTypeStorage>();
// Copy the shape into the bump pointer.
shape = impl.copyInto(shape);
@ -755,8 +763,13 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
// Initialize the memory using placement new.
new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace,
context);
new (result) MemRefTypeStorage{
{Kind::MemRef, context, static_cast<unsigned int>(shape.size())},
elementType,
shape.data(),
static_cast<unsigned int>(affineMapComposition.size()),
affineMapComposition.data(),
memorySpace};
// Cache and return it.
return *existing.first = result;
}
@ -895,7 +908,7 @@ IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
return result;
}
TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
TypeAttr TypeAttr::get(Type type, MLIRContext *context) {
auto *&result = context->getImpl().typeAttrs[type];
if (result)
return result;
@ -1009,9 +1022,9 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
return *existing.first = result;
}
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
Attribute elt) {
auto &impl = type->getContext()->getImpl();
auto &impl = type.getContext()->getImpl();
// Look to see if we already have this.
auto *&result = impl.splatElementsAttrs[{type, elt}];
@ -1030,14 +1043,14 @@ SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
return result;
}
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
ArrayRef<char> data) {
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
auto bitsRequired = (long)type.getBitWidth() * type.getNumElements();
(void)bitsRequired;
assert((bitsRequired <= data.size() * 8L) &&
"Input data bit size should be larger than that type requires");
auto &impl = type->getContext()->getImpl();
auto &impl = type.getContext()->getImpl();
// Look to see if this constant is already defined.
DenseElementsAttrInfo::KeyTy key({type, data});
@ -1048,8 +1061,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
return *existing.first;
// Otherwise, allocate a new one, unique it and return it.
auto *eltType = type->getElementType();
switch (eltType->getKind()) {
auto eltType = type.getElementType();
switch (eltType.getKind()) {
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
@ -1064,7 +1077,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
return *existing.first = result;
}
case Type::Kind::Integer: {
auto width = ::cast<IntegerType>(eltType)->getWidth();
auto width = eltType.cast<IntegerType>().getWidth();
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
@ -1080,12 +1093,12 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
}
}
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
StringRef bytes) {
assert(isValidTensorElementType(type->getElementType()) &&
assert(isValidTensorElementType(type.getElementType()) &&
"Input element type should be a valid tensor element type");
auto &impl = type->getContext()->getImpl();
auto &impl = type.getContext()->getImpl();
// Look to see if this constant is already defined.
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
@ -1104,10 +1117,10 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
return *existing.first = result;
}
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {
auto &impl = type->getContext()->getImpl();
auto &impl = type.getContext()->getImpl();
// Look to see if we already have this.
auto key = std::make_tuple(type, indices, values);

View File

@ -377,7 +377,7 @@ bool OpTrait::impl::verifyAtLeastNResults(const Operation *op,
}
bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
auto *type = op->getResult(0)->getType();
auto type = op->getResult(0)->getType();
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
if (op->getResult(i)->getType() != type)
return op->emitOpError(
@ -393,19 +393,19 @@ bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
/// If this is a vector type, or a tensor type, return the scalar element type
/// that it is built around, otherwise return the type unmodified.
static Type *getTensorOrVectorElementType(Type *type) {
if (auto *vec = dyn_cast<VectorType>(type))
return vec->getElementType();
static Type getTensorOrVectorElementType(Type type) {
if (auto vec = type.dyn_cast<VectorType>())
return vec.getElementType();
// Look through tensor<vector<...>> to find the underlying element type.
if (auto *tensor = dyn_cast<TensorType>(type))
return getTensorOrVectorElementType(tensor->getElementType());
if (auto tensor = type.dyn_cast<TensorType>())
return getTensorOrVectorElementType(tensor.getElementType());
return type;
}
bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
for (auto *result : op->getResults()) {
if (!isa<FloatType>(getTensorOrVectorElementType(result->getType())))
if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
return op->emitOpError("requires a floating point type");
}
@ -414,7 +414,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
for (auto *result : op->getResults()) {
if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType())))
if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>())
return op->emitOpError("requires an integer type");
}
return false;
@ -436,7 +436,7 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result,
bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
Type type;
return parser->parseOperandList(ops, 2) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
@ -448,7 +448,7 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
*p << op->getName() << ' ' << *op->getOperand(0) << ", "
<< *op->getOperand(1);
p->printOptionalAttrDict(op->getAttrs());
*p << " : " << *op->getResult(0)->getType();
*p << " : " << op->getResult(0)->getType();
}
//===----------------------------------------------------------------------===//
@ -456,14 +456,14 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
//===----------------------------------------------------------------------===//
void impl::buildCastOp(Builder *builder, OperationState *result,
SSAValue *source, Type *destType) {
SSAValue *source, Type destType) {
result->addOperands(source);
result->addTypes(destType);
}
bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType srcInfo;
Type *srcType, *dstType;
Type srcType, dstType;
return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
parser->resolveOperand(srcInfo, srcType, result->operands) ||
parser->parseKeywordType("to", dstType) ||
@ -472,5 +472,5 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
*p << op->getName() << ' ' << *op->getOperand(0) << " : "
<< *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType();
<< op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
}

View File

@ -239,7 +239,7 @@ void Statement::moveBefore(StmtBlock *block,
/// Create a new OperationStmt with the specific fields.
OperationStmt *OperationStmt::create(Location *location, OperationName name,
ArrayRef<MLValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context) {
auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
@ -288,9 +288,9 @@ MLIRContext *OperationStmt::getContext() const {
// If we have a result or operand type, that is a constant time way to get
// to the context.
if (getNumResults())
return getResult(0)->getType()->getContext();
return getResult(0)->getType().getContext();
if (getNumOperands())
return getOperand(0)->getType()->getContext();
return getOperand(0)->getType().getContext();
// In the very odd case where we have no operands or results, fall back to
// doing a find.
@ -474,7 +474,7 @@ MLIRContext *IfStmt::getContext() const {
if (operands.empty())
return findFunction()->getContext();
return getOperand(0)->getType()->getContext();
return getOperand(0)->getType().getContext();
}
//===----------------------------------------------------------------------===//
@ -501,7 +501,7 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
operands.push_back(remapOperand(opValue));
if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
SmallVector<Type *, 8> resultTypes;
SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults());
for (auto *result : opStmt->getResults())
resultTypes.push_back(result->getType());

126
mlir/lib/IR/TypeDetail.h Normal file
View File

@ -0,0 +1,126 @@
//===- TypeDetail.h - MLIR Affine Expr storage details ----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This holds implementation details of Type.
//
//===----------------------------------------------------------------------===//
#ifndef TYPEDETAIL_H_
#define TYPEDETAIL_H_
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
namespace mlir {
class AffineMap;
class MLIRContext;
namespace detail {
/// Base storage class appearing in a Type.
struct alignas(8) TypeStorage {
TypeStorage(Type::Kind kind, MLIRContext *context)
: context(context), kind(kind), subclassData(0) {}
TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData)
: context(context), kind(kind), subclassData(subclassData) {}
unsigned getSubclassData() const { return subclassData; }
void setSubclassData(unsigned val) {
subclassData = val;
// Ensure we don't have any accidental truncation.
assert(getSubclassData() == val && "Subclass data too large for field");
}
/// This refers to the MLIRContext in which this type was uniqued.
MLIRContext *const context;
/// Classification of the subclass, used for type checking.
Type::Kind kind : 8;
/// Space for subclasses to store data.
unsigned subclassData : 24;
};
struct IntegerTypeStorage : public TypeStorage {
unsigned width;
};
struct FloatTypeStorage : public TypeStorage {};
struct OtherTypeStorage : public TypeStorage {};
struct FunctionTypeStorage : public TypeStorage {
ArrayRef<Type> getInputs() const {
return ArrayRef<Type>(inputsAndResults, subclassData);
}
ArrayRef<Type> getResults() const {
return ArrayRef<Type>(inputsAndResults + subclassData, numResults);
}
unsigned numResults;
Type const *inputsAndResults;
};
struct VectorOrTensorTypeStorage : public TypeStorage {
Type elementType;
};
struct VectorTypeStorage : public VectorOrTensorTypeStorage {
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
const int *shapeElements;
};
struct TensorTypeStorage : public VectorOrTensorTypeStorage {};
struct RankedTensorTypeStorage : public TensorTypeStorage {
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
const int *shapeElements;
};
struct UnrankedTensorTypeStorage : public TensorTypeStorage {};
struct MemRefTypeStorage : public TypeStorage {
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
ArrayRef<AffineMap> getAffineMaps() const {
return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
}
/// The type of each scalar element of the memref.
Type elementType;
/// An array of integers which stores the shape dimension sizes.
const int *shapeElements;
/// The number of affine maps in the 'affineMapList' array.
const unsigned numAffineMaps;
/// List of affine maps in the memref's layout/index map composition.
AffineMap const *affineMapList;
/// Memory space in which data referenced by memref resides.
const unsigned memorySpace;
};
} // namespace detail
} // namespace mlir
#endif // TYPEDETAIL_H_

View File

@ -16,10 +16,17 @@
// =============================================================================
#include "mlir/IR/Types.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::detail;
Type::Kind Type::getKind() const { return type->kind; }
MLIRContext *Type::getContext() const { return type->context; }
unsigned Type::getBitWidth() const {
switch (getKind()) {
@ -32,34 +39,49 @@ unsigned Type::getBitWidth() const {
case Type::Kind::F64:
return 64;
case Type::Kind::Integer:
return cast<IntegerType>(this)->getWidth();
return cast<IntegerType>().getWidth();
case Type::Kind::Vector:
case Type::Kind::RankedTensor:
case Type::Kind::UnrankedTensor:
return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth();
return cast<VectorOrTensorType>().getElementType().getBitWidth();
// TODO: Handle more types.
default:
llvm_unreachable("unexpected type");
}
}
IntegerType::IntegerType(unsigned width, MLIRContext *context)
: Type(Kind::Integer, context), width(width) {
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
unsigned Type::getSubclassData() const { return type->getSubclassData(); }
void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
unsigned IntegerType::getWidth() const {
return static_cast<ImplType *>(type)->width;
}
FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {}
FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {}
OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {}
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
unsigned numResults, MLIRContext *context)
: Type(Kind::Function, context, numInputs), numResults(numResults),
inputsAndResults(inputsAndResults) {}
FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {}
VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
Type *elementType, unsigned subClassData)
: Type(kind, context, subClassData), elementType(elementType) {}
ArrayRef<Type> FunctionType::getInputs() const {
return static_cast<ImplType *>(type)->getInputs();
}
unsigned FunctionType::getNumResults() const {
return static_cast<ImplType *>(type)->numResults;
}
ArrayRef<Type> FunctionType::getResults() const {
return static_cast<ImplType *>(type)->getResults();
}
VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {}
Type VectorOrTensorType::getElementType() const {
return static_cast<ImplType *>(type)->elementType;
}
unsigned VectorOrTensorType::getNumElements() const {
switch (getKind()) {
@ -103,11 +125,11 @@ int VectorOrTensorType::getDimSize(unsigned i) const {
ArrayRef<int> VectorOrTensorType::getShape() const {
switch (getKind()) {
case Kind::Vector:
return cast<VectorType>(this)->getShape();
return cast<VectorType>().getShape();
case Kind::RankedTensor:
return cast<RankedTensorType>(this)->getShape();
return cast<RankedTensorType>().getShape();
case Kind::UnrankedTensor:
return cast<RankedTensorType>(this)->getShape();
return cast<RankedTensorType>().getShape();
default:
llvm_unreachable("not a VectorOrTensorType");
}
@ -118,35 +140,38 @@ bool VectorOrTensorType::hasStaticShape() const {
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
}
VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context)
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
shapeElements(shape.data()) {}
VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
: VectorOrTensorType(kind, context, elementType) {
assert(isValidTensorElementType(elementType));
ArrayRef<int> VectorType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
}
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context)
: TensorType(Kind::RankedTensor, elementType, context),
shapeElements(shape.data()) {
setSubclassData(shape.size());
TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
ArrayRef<int> RankedTensorType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
}
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
: TensorType(Kind::UnrankedTensor, elementType, context) {}
UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
MLIRContext *context)
: Type(Kind::MemRef, context, shape.size()), elementType(elementType),
shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {}
ArrayRef<int> MemRefType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
}
Type MemRefType::getElementType() const {
return static_cast<ImplType *>(type)->elementType;
}
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
return static_cast<ImplType *>(type)->getAffineMaps();
}
unsigned MemRefType::getMemorySpace() const {
return static_cast<ImplType *>(type)->memorySpace;
}
unsigned MemRefType::getNumDynamicDims() const {

View File

@ -182,19 +182,19 @@ public:
// as the results of their action.
// Type parsing.
VectorType *parseVectorType();
VectorType parseVectorType();
ParseResult parseXInDimensionList();
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
Type *parseTensorType();
Type *parseMemRefType();
Type *parseFunctionType();
Type *parseType();
ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
ParseResult parseTypeList(SmallVectorImpl<Type *> &elements);
Type parseTensorType();
Type parseMemRefType();
Type parseFunctionType();
Type parseType();
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
ParseResult parseTypeList(SmallVectorImpl<Type> &elements);
// Attribute parsing.
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType *type);
FunctionType type);
Attribute parseAttribute();
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
@ -206,9 +206,9 @@ public:
AffineMap parseAffineMapReference();
IntegerSet parseIntegerSetInline();
IntegerSet parseIntegerSetReference();
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type);
DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector);
VectorOrTensorType *parseVectorOrTensorType();
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type);
DenseElementsAttr parseDenseElementsAttr(Type eltType, bool isVector);
VectorOrTensorType parseVectorOrTensorType();
private:
// The Parser is subclassed and reinstantiated. Do not add additional
@ -299,7 +299,7 @@ ParseResult Parser::parseCommaSeparatedListUntil(
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
/// other-type ::= `index` | `tf_control`
///
Type *Parser::parseType() {
Type Parser::parseType() {
switch (getToken().getKind()) {
default:
return (emitError("expected type"), nullptr);
@ -368,7 +368,7 @@ Type *Parser::parseType() {
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
/// const-dimension-list ::= (integer-literal `x`)+
///
VectorType *Parser::parseVectorType() {
VectorType Parser::parseVectorType() {
consumeToken(Token::kw_vector);
if (parseToken(Token::less, "expected '<' in vector type"))
@ -402,11 +402,11 @@ VectorType *Parser::parseVectorType() {
// Parse the element type.
auto typeLoc = getToken().getLoc();
auto *elementType = parseType();
auto elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
return (emitError(typeLoc, "invalid vector element type"), nullptr);
return VectorType::get(dimensions, elementType);
@ -461,7 +461,7 @@ ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
/// tensor-type ::= `tensor` `<` dimension-list element-type `>`
/// dimension-list ::= dimension-list-ranked | `*x`
///
Type *Parser::parseTensorType() {
Type Parser::parseTensorType() {
consumeToken(Token::kw_tensor);
if (parseToken(Token::less, "expected '<' in tensor type"))
@ -485,7 +485,7 @@ Type *Parser::parseTensorType() {
// Parse the element type.
auto typeLoc = getToken().getLoc();
auto *elementType = parseType();
auto elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
return nullptr;
@ -505,7 +505,7 @@ Type *Parser::parseTensorType() {
/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
/// memory-space ::= integer-literal /* | TODO: address-space-id */
///
Type *Parser::parseMemRefType() {
Type Parser::parseMemRefType() {
consumeToken(Token::kw_memref);
if (parseToken(Token::less, "expected '<' in memref type"))
@ -517,12 +517,12 @@ Type *Parser::parseMemRefType() {
// Parse the element type.
auto typeLoc = getToken().getLoc();
auto *elementType = parseType();
auto elementType = parseType();
if (!elementType)
return nullptr;
if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
!isa<VectorType>(elementType))
if (!elementType.isa<IntegerType>() && !elementType.isa<FloatType>() &&
!elementType.isa<VectorType>())
return (emitError(typeLoc, "invalid memref element type"), nullptr);
// Parse semi-affine-map-composition.
@ -581,10 +581,10 @@ Type *Parser::parseMemRefType() {
///
/// function-type ::= type-list-parens `->` type-list
///
Type *Parser::parseFunctionType() {
Type Parser::parseFunctionType() {
assert(getToken().is(Token::l_paren));
SmallVector<Type *, 4> arguments, results;
SmallVector<Type, 4> arguments, results;
if (parseTypeList(arguments) ||
parseToken(Token::arrow, "expected '->' in function type") ||
parseTypeList(results))
@ -598,7 +598,7 @@ Type *Parser::parseFunctionType() {
///
/// type-list-no-parens ::= type (`,` type)*
///
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
auto parseElt = [&]() -> ParseResult {
auto elt = parseType();
elements.push_back(elt);
@ -615,7 +615,7 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
/// type-list-parens ::= `(` `)`
/// | `(` type-list-no-parens `)`
///
ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
auto parseElt = [&]() -> ParseResult {
auto elt = parseType();
elements.push_back(elt);
@ -639,8 +639,8 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
namespace {
class TensorLiteralParser {
public:
TensorLiteralParser(Parser &p, Type *eltTy)
: p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {}
TensorLiteralParser(Parser &p, Type eltTy)
: p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {}
ParseResult parse() { return parseList(shape); }
@ -676,7 +676,7 @@ private:
}
Parser &p;
Type *eltTy;
Type eltTy;
size_t currBitPos;
size_t bitsWidth;
SmallVector<int, 4> shape;
@ -698,7 +698,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
if (!result)
return p.emitError("expected tensor element");
// check result matches the element type.
switch (eltTy->getKind()) {
switch (eltTy.getKind()) {
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
@ -779,7 +779,7 @@ ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl<int> &dims) {
/// synthesizing a forward reference) or emit an error and return null on
/// failure.
Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType *type) {
FunctionType type) {
Identifier name = builder.getIdentifier(nameStr.drop_front());
// See if the function has already been defined in the module.
@ -902,10 +902,10 @@ Attribute Parser::parseAttribute() {
if (parseToken(Token::colon, "expected ':' and function type"))
return nullptr;
auto typeLoc = getToken().getLoc();
Type *type = parseType();
Type type = parseType();
if (!type)
return nullptr;
auto *fnType = dyn_cast<FunctionType>(type);
auto fnType = type.dyn_cast<FunctionType>();
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
@ -916,7 +916,7 @@ Attribute Parser::parseAttribute() {
consumeToken(Token::kw_opaque);
if (parseToken(Token::less, "expected '<' after 'opaque'"))
return nullptr;
auto *type = parseVectorOrTensorType();
auto type = parseVectorOrTensorType();
if (!type)
return nullptr;
auto val = getToken().getStringValue();
@ -937,7 +937,7 @@ Attribute Parser::parseAttribute() {
if (parseToken(Token::less, "expected '<' after 'splat'"))
return nullptr;
auto *type = parseVectorOrTensorType();
auto type = parseVectorOrTensorType();
if (!type)
return nullptr;
switch (getToken().getKind()) {
@ -959,7 +959,7 @@ Attribute Parser::parseAttribute() {
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
auto *type = parseVectorOrTensorType();
auto type = parseVectorOrTensorType();
if (!type)
return nullptr;
@ -981,41 +981,41 @@ Attribute Parser::parseAttribute() {
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
return nullptr;
auto *type = parseVectorOrTensorType();
auto type = parseVectorOrTensorType();
if (!type)
return nullptr;
switch (getToken().getKind()) {
case Token::l_square: {
/// Parse indices
auto *indicesEltType = builder.getIntegerType(32);
auto indicesEltType = builder.getIntegerType(32);
auto indices =
parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
if (parseToken(Token::comma, "expected ','"))
return nullptr;
/// Parse values.
auto *valuesEltType = type->getElementType();
auto valuesEltType = type.getElementType();
auto values =
parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
parseDenseElementsAttr(valuesEltType, type.isa<VectorType>());
/// Sanity check.
auto *indicesType = indices.getType();
auto *valuesType = values.getType();
auto sameShape = (indicesType->getRank() == 1) ||
(type->getRank() == indicesType->getDimSize(1));
auto indicesType = indices.getType();
auto valuesType = values.getType();
auto sameShape = (indicesType.getRank() == 1) ||
(type.getRank() == indicesType.getDimSize(1));
auto sameElementNum =
indicesType->getDimSize(0) == valuesType->getDimSize(0);
indicesType.getDimSize(0) == valuesType.getDimSize(0);
if (!sameShape || !sameElementNum) {
std::string str;
llvm::raw_string_ostream s(str);
s << "expected shape ([";
interleaveComma(type->getShape(), s);
interleaveComma(type.getShape(), s);
s << "]); inferred shape of indices literal ([";
interleaveComma(indicesType->getShape(), s);
interleaveComma(indicesType.getShape(), s);
s << "]); inferred shape of values literal ([";
interleaveComma(valuesType->getShape(), s);
interleaveComma(valuesType.getShape(), s);
s << "])";
return (emitError(s.str()), nullptr);
}
@ -1035,7 +1035,7 @@ Attribute Parser::parseAttribute() {
nullptr);
}
default: {
if (Type *type = parseType())
if (Type type = parseType())
return builder.getTypeAttr(type);
return nullptr;
}
@ -1051,12 +1051,12 @@ Attribute Parser::parseAttribute() {
///
/// This method returns a constructed dense elements attribute with the shape
/// from the parsing result.
DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
DenseElementsAttr Parser::parseDenseElementsAttr(Type eltType, bool isVector) {
TensorLiteralParser literalParser(*this, eltType);
if (literalParser.parse())
return nullptr;
VectorOrTensorType *type;
VectorOrTensorType type;
if (isVector) {
type = builder.getVectorType(literalParser.getShape(), eltType);
} else {
@ -1076,18 +1076,18 @@ DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
/// This method compares the shapes from the parsing result and that from the
/// input argument. It returns a constructed dense elements attribute if both
/// match.
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
auto *eltTy = type->getElementType();
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
auto eltTy = type.getElementType();
TensorLiteralParser literalParser(*this, eltTy);
if (literalParser.parse())
return nullptr;
if (literalParser.getShape() != type->getShape()) {
if (literalParser.getShape() != type.getShape()) {
std::string str;
llvm::raw_string_ostream s(str);
s << "inferred shape of elements literal ([";
interleaveComma(literalParser.getShape(), s);
s << "]) does not match type ([";
interleaveComma(type->getShape(), s);
interleaveComma(type.getShape(), s);
s << "])";
return (emitError(s.str()), nullptr);
}
@ -1100,8 +1100,8 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
/// vector-or-tensor-type ::= vector-type | tensor-type
///
/// This method also checks the type has static shape and ranked.
VectorOrTensorType *Parser::parseVectorOrTensorType() {
auto *type = dyn_cast<VectorOrTensorType>(parseType());
VectorOrTensorType Parser::parseVectorOrTensorType() {
auto type = parseType().dyn_cast<VectorOrTensorType>();
if (!type) {
return (emitError("expected elements literal has a tensor or vector type"),
nullptr);
@ -1110,7 +1110,7 @@ VectorOrTensorType *Parser::parseVectorOrTensorType() {
if (parseToken(Token::comma, "expected ','"))
return nullptr;
if (!type->hasStaticShape() || type->getRank() == -1) {
if (!type.hasStaticShape() || type.getRank() == -1) {
return (emitError("tensor literals must be ranked and have static shape"),
nullptr);
}
@ -1834,7 +1834,7 @@ public:
/// Given a reference to an SSA value and its type, return a reference. This
/// returns null on failure.
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type);
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type);
/// Register a definition of a value with the symbol table.
ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
@ -1845,11 +1845,11 @@ public:
template <typename ResultType>
ResultType parseSSADefOrUseAndType(
const std::function<ResultType(SSAUseInfo, Type *)> &action);
const std::function<ResultType(SSAUseInfo, Type)> &action);
SSAValue *parseSSAUseAndType() {
return parseSSADefOrUseAndType<SSAValue *>(
[&](SSAUseInfo useInfo, Type *type) -> SSAValue * {
[&](SSAUseInfo useInfo, Type type) -> SSAValue * {
return resolveSSAUse(useInfo, type);
});
}
@ -1880,7 +1880,7 @@ private:
/// their first reference, to allow checking for use of undefined values.
DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type *type);
SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type);
/// Return true if this is a forward reference.
bool isForwardReferencePlaceholder(SSAValue *value) {
@ -1891,7 +1891,7 @@ private:
/// Create and remember a new placeholder for a forward reference.
SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
Type *type) {
Type type) {
// Forward references are always created as instructions, even in ML
// functions, because we just need something with a def/use chain.
//
@ -1908,7 +1908,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
/// Given an unbound reference to an SSA value and its type, return the value
/// it specifies. This returns null on failure.
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) {
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
auto &entries = values[useInfo.name];
// If we have already seen a value of this name, return it.
@ -2057,14 +2057,14 @@ FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
/// ssa-use-and-type ::= ssa-use `:` type
template <typename ResultType>
ResultType FunctionParser::parseSSADefOrUseAndType(
const std::function<ResultType(SSAUseInfo, Type *)> &action) {
const std::function<ResultType(SSAUseInfo, Type)> &action) {
SSAUseInfo useInfo;
if (parseSSAUse(useInfo) ||
parseToken(Token::colon, "expected ':' and type for SSA operand"))
return nullptr;
auto *type = parseType();
auto type = parseType();
if (!type)
return nullptr;
@ -2101,7 +2101,7 @@ ParseResult FunctionParser::parseOptionalSSAUseAndTypeList(
if (valueIDs.empty())
return ParseSuccess;
SmallVector<Type *, 4> types;
SmallVector<Type, 4> types;
if (parseToken(Token::colon, "expected ':' in operand list") ||
parseTypeListNoParens(types))
return ParseFailure;
@ -2209,14 +2209,14 @@ Operation *FunctionParser::parseVerboseOperation(
auto type = parseType();
if (!type)
return nullptr;
auto fnType = dyn_cast<FunctionType>(type);
auto fnType = type.dyn_cast<FunctionType>();
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
result.addTypes(fnType->getResults());
result.addTypes(fnType.getResults());
// Check that we have the right number of types for the operands.
auto operandTypes = fnType->getInputs();
auto operandTypes = fnType.getInputs();
if (operandTypes.size() != operandInfos.size()) {
auto plural = "s"[operandInfos.size() == 1];
return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
@ -2253,17 +2253,17 @@ public:
return parser.parseToken(Token::comma, "expected ','");
}
bool parseColonType(Type *&result) override {
bool parseColonType(Type &result) override {
return parser.parseToken(Token::colon, "expected ':'") ||
!(result = parser.parseType());
}
bool parseColonTypeList(SmallVectorImpl<Type *> &result) override {
bool parseColonTypeList(SmallVectorImpl<Type> &result) override {
if (parser.parseToken(Token::colon, "expected ':'"))
return true;
do {
if (auto *type = parser.parseType())
if (auto type = parser.parseType())
result.push_back(type);
else
return true;
@ -2273,7 +2273,7 @@ public:
}
/// Parse a keyword followed by a type.
bool parseKeywordType(const char *keyword, Type *&result) override {
bool parseKeywordType(const char *keyword, Type &result) override {
if (parser.getTokenSpelling() != keyword)
return parser.emitError("expected '" + Twine(keyword) + "'");
parser.consumeToken();
@ -2396,7 +2396,7 @@ public:
}
/// Resolve a parse function name and a type into a function reference.
virtual bool resolveFunctionName(StringRef name, FunctionType *type,
virtual bool resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) {
result = parser.resolveFunctionReference(name, loc, type);
return result == nullptr;
@ -2410,7 +2410,7 @@ public:
llvm::SMLoc getNameLoc() const override { return nameLoc; }
bool resolveOperand(const OperandType &operand, Type *type,
bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<SSAValue *> &result) override {
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location};
@ -2559,11 +2559,11 @@ ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList(
return ParseSuccess;
return parseCommaSeparatedList([&]() -> ParseResult {
auto type = parseSSADefOrUseAndType<Type *>(
[&](SSAUseInfo useInfo, Type *type) -> Type * {
auto type = parseSSADefOrUseAndType<Type>(
[&](SSAUseInfo useInfo, Type type) -> Type {
BBArgument *arg = owner->addArgument(type);
if (addDefinition(useInfo, arg))
return nullptr;
return {};
return type;
});
return type ? ParseSuccess : ParseFailure;
@ -2908,7 +2908,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
" symbol count must match");
// Resolve SSA uses.
Type *indexType = builder.getIndexType();
Type indexType = builder.getIndexType();
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
if (!sval)
@ -3187,9 +3187,9 @@ private:
ParseResult parseAffineStructureDef();
// Functions.
ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
ParseResult parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<StringRef> &argNames);
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
ParseResult parseFunctionSignature(StringRef &name, FunctionType &type,
SmallVectorImpl<StringRef> *argNames);
ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs);
ParseResult parseExtFunc();
@ -3248,7 +3248,7 @@ ParseResult ModuleParser::parseAffineStructureDef() {
/// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
///
ParseResult
ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
ModuleParser::parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<StringRef> &argNames) {
consumeToken(Token::l_paren);
@ -3284,7 +3284,7 @@ ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
/// type-list)?
///
ParseResult
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
SmallVectorImpl<StringRef> *argNames) {
if (getToken().isNot(Token::at_identifier))
return emitError("expected a function identifier like '@foo'");
@ -3295,7 +3295,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
if (getToken().isNot(Token::l_paren))
return emitError("expected '(' in function signature");
SmallVector<Type *, 4> argTypes;
SmallVector<Type, 4> argTypes;
ParseResult parseResult;
if (argNames)
@ -3307,7 +3307,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
return ParseFailure;
// Parse the return type if present.
SmallVector<Type *, 4> results;
SmallVector<Type, 4> results;
if (consumeIf(Token::arrow)) {
if (parseTypeList(results))
return ParseFailure;
@ -3340,7 +3340,7 @@ ParseResult ModuleParser::parseExtFunc() {
auto loc = getToken().getLoc();
StringRef name;
FunctionType *type = nullptr;
FunctionType type;
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
return ParseFailure;
@ -3372,7 +3372,7 @@ ParseResult ModuleParser::parseCFGFunc() {
auto loc = getToken().getLoc();
StringRef name;
FunctionType *type = nullptr;
FunctionType type;
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
return ParseFailure;
@ -3405,7 +3405,7 @@ ParseResult ModuleParser::parseMLFunc() {
consumeToken(Token::kw_mlfunc);
StringRef name;
FunctionType *type = nullptr;
FunctionType type;
SmallVector<StringRef, 4> argNames;
auto loc = getToken().getLoc();

View File

@ -138,23 +138,23 @@ void AddIOp::getCanonicalizationPatterns(OwningPatternList &results,
//===----------------------------------------------------------------------===//
void AllocOp::build(Builder *builder, OperationState *result,
MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
MemRefType memrefType, ArrayRef<SSAValue *> operands) {
result->addOperands(operands);
result->types.push_back(memrefType);
}
void AllocOp::print(OpAsmPrinter *p) const {
MemRefType *type = getType();
MemRefType type = getType();
*p << "alloc";
// Print dynamic dimension operands.
printDimAndSymbolList(operand_begin(), operand_end(),
type->getNumDynamicDims(), p);
type.getNumDynamicDims(), p);
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
*p << " : " << *type;
*p << " : " << type;
}
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MemRefType *type;
MemRefType type;
// Parse the dimension operands and optional symbol operands, followed by a
// memref type.
@ -170,7 +170,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
// Verification still checks that the total number of operands matches
// the number of symbols in the affine map, plus the number of dynamic
// dimensions in the memref.
if (numDimOperands != type->getNumDynamicDims()) {
if (numDimOperands != type.getNumDynamicDims()) {
return parser->emitError(parser->getNameLoc(),
"dimension operand count does not equal memref "
"dynamic dimension count");
@ -180,13 +180,13 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
}
bool AllocOp::verify() const {
auto *memRefType = dyn_cast<MemRefType>(getResult()->getType());
auto memRefType = getResult()->getType().dyn_cast<MemRefType>();
if (!memRefType)
return emitOpError("result must be a memref");
unsigned numSymbols = 0;
if (!memRefType->getAffineMaps().empty()) {
AffineMap affineMap = memRefType->getAffineMaps()[0];
if (!memRefType.getAffineMaps().empty()) {
AffineMap affineMap = memRefType.getAffineMaps()[0];
// Store number of symbols used in affine map (used in subsequent check).
numSymbols = affineMap.getNumSymbols();
// TODO(zinenko): this check does not belong to AllocOp, or any other op but
@ -195,10 +195,10 @@ bool AllocOp::verify() const {
// Remove when we can emit errors directly from *Type::get(...) functions.
//
// Verify that the layout affine map matches the rank of the memref.
if (affineMap.getNumDims() != memRefType->getRank())
if (affineMap.getNumDims() != memRefType.getRank())
return emitOpError("affine map dimension count must equal memref rank");
}
unsigned numDynamicDims = memRefType->getNumDynamicDims();
unsigned numDynamicDims = memRefType.getNumDynamicDims();
// Check that the total number of operands matches the number of symbols in
// the affine map, plus the number of dynamic dimensions specified in the
// memref type.
@ -208,7 +208,7 @@ bool AllocOp::verify() const {
}
// Verify that all operands are of type Index.
for (auto *operand : getOperands()) {
if (!operand->getType()->isIndex())
if (!operand->getType().isIndex())
return emitOpError("requires operands to be of type Index");
}
return false;
@ -239,13 +239,13 @@ struct SimplifyAllocConst : public Pattern {
// Ok, we have one or more constant operands. Collect the non-constant ones
// and keep track of the resultant memref type to build.
SmallVector<int, 4> newShapeConstants;
newShapeConstants.reserve(memrefType->getRank());
newShapeConstants.reserve(memrefType.getRank());
SmallVector<SSAValue *, 4> newOperands;
SmallVector<SSAValue *, 4> droppedOperands;
unsigned dynamicDimPos = 0;
for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) {
int dimSize = memrefType->getDimSize(dim);
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
int dimSize = memrefType.getDimSize(dim);
// If this is already static dimension, keep it.
if (dimSize != -1) {
newShapeConstants.push_back(dimSize);
@ -267,10 +267,10 @@ struct SimplifyAllocConst : public Pattern {
}
// Create new memref type (which will have fewer dynamic dimensions).
auto *newMemRefType = MemRefType::get(
newShapeConstants, memrefType->getElementType(),
memrefType->getAffineMaps(), memrefType->getMemorySpace());
assert(newOperands.size() == newMemRefType->getNumDynamicDims());
auto newMemRefType = MemRefType::get(
newShapeConstants, memrefType.getElementType(),
memrefType.getAffineMaps(), memrefType.getMemorySpace());
assert(newOperands.size() == newMemRefType.getNumDynamicDims());
// Create and insert the alloc op for the new memref.
auto newAlloc =
@ -297,13 +297,13 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<SSAValue *> operands) {
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(callee->getType()->getResults());
result->addTypes(callee->getType().getResults());
}
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
StringRef calleeName;
llvm::SMLoc calleeLoc;
FunctionType *calleeType = nullptr;
FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands;
Function *callee = nullptr;
if (parser->parseFunctionName(calleeName, calleeLoc) ||
@ -312,8 +312,8 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
parser->addTypesToList(calleeType->getResults(), result->types) ||
parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
parser->addTypesToList(calleeType.getResults(), result->types) ||
parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
result->operands))
return true;
@ -328,7 +328,7 @@ void CallOp::print(OpAsmPrinter *p) const {
p->printOperands(getOperands());
*p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
*p << " : " << *getCallee()->getType();
*p << " : " << getCallee()->getType();
}
bool CallOp::verify() const {
@ -338,20 +338,20 @@ bool CallOp::verify() const {
return emitOpError("requires a 'callee' function attribute");
// Verify that the operand and result types match the callee.
auto *fnType = fnAttr.getValue()->getType();
if (fnType->getNumInputs() != getNumOperands())
auto fnType = fnAttr.getValue()->getType();
if (fnType.getNumInputs() != getNumOperands())
return emitOpError("incorrect number of operands for callee");
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
if (getOperand(i)->getType() != fnType->getInput(i))
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
if (getOperand(i)->getType() != fnType.getInput(i))
return emitOpError("operand type mismatch");
}
if (fnType->getNumResults() != getNumResults())
if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee");
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
if (getResult(i)->getType() != fnType->getResult(i))
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
if (getResult(i)->getType() != fnType.getResult(i))
return emitOpError("result type mismatch");
}
@ -364,14 +364,14 @@ bool CallOp::verify() const {
void CallIndirectOp::build(Builder *builder, OperationState *result,
SSAValue *callee, ArrayRef<SSAValue *> operands) {
auto *fnType = cast<FunctionType>(callee->getType());
auto fnType = callee->getType().cast<FunctionType>();
result->operands.push_back(callee);
result->addOperands(operands);
result->addTypes(fnType->getResults());
result->addTypes(fnType.getResults());
}
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
FunctionType *calleeType = nullptr;
FunctionType calleeType;
OpAsmParser::OperandType callee;
llvm::SMLoc operandsLoc;
SmallVector<OpAsmParser::OperandType, 4> operands;
@ -382,9 +382,9 @@ bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveOperand(callee, calleeType, result->operands) ||
parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
result->operands) ||
parser->addTypesToList(calleeType->getResults(), result->types);
parser->addTypesToList(calleeType.getResults(), result->types);
}
void CallIndirectOp::print(OpAsmPrinter *p) const {
@ -395,29 +395,29 @@ void CallIndirectOp::print(OpAsmPrinter *p) const {
p->printOperands(++operandRange.begin(), operandRange.end());
*p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
*p << " : " << *getCallee()->getType();
*p << " : " << getCallee()->getType();
}
bool CallIndirectOp::verify() const {
// The callee must be a function.
auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
if (!fnType)
return emitOpError("callee must have function type");
// Verify that the operand and result types match the callee.
if (fnType->getNumInputs() != getNumOperands() - 1)
if (fnType.getNumInputs() != getNumOperands() - 1)
return emitOpError("incorrect number of operands for callee");
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
if (getOperand(i + 1)->getType() != fnType->getInput(i))
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
if (getOperand(i + 1)->getType() != fnType.getInput(i))
return emitOpError("operand type mismatch");
}
if (fnType->getNumResults() != getNumResults())
if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee");
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
if (getResult(i)->getType() != fnType->getResult(i))
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
if (getResult(i)->getType() != fnType.getResult(i))
return emitOpError("result type mismatch");
}
@ -434,19 +434,19 @@ void DeallocOp::build(Builder *builder, OperationState *result,
}
void DeallocOp::print(OpAsmPrinter *p) const {
*p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
*p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType();
}
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
MemRefType *type;
MemRefType type;
return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands);
}
bool DeallocOp::verify() const {
if (!isa<MemRefType>(getMemRef()->getType()))
if (!getMemRef()->getType().isa<MemRefType>())
return emitOpError("operand must be a memref");
return false;
}
@ -472,13 +472,13 @@ void DimOp::build(Builder *builder, OperationState *result,
void DimOp::print(OpAsmPrinter *p) const {
*p << "dim " << *getOperand() << ", " << getIndex();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
*p << " : " << *getOperand()->getType();
*p << " : " << getOperand()->getType();
}
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo;
IntegerAttr indexAttr;
Type *type;
Type type;
return parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr, "index", result->attributes) ||
@ -496,15 +496,15 @@ bool DimOp::verify() const {
return emitOpError("requires an integer attribute named 'index'");
uint64_t index = (uint64_t)indexAttr.getValue();
auto *type = getOperand()->getType();
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
if (index >= tensorType->getRank())
auto type = getOperand()->getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
if (index >= tensorType.getRank())
return emitOpError("index is out of range");
} else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
if (index >= memrefType->getRank())
} else if (auto memrefType = type.dyn_cast<MemRefType>()) {
if (index >= memrefType.getRank())
return emitOpError("index is out of range");
} else if (isa<UnrankedTensorType>(type)) {
} else if (type.isa<UnrankedTensorType>()) {
// ok, assumed to be in-range.
} else {
return emitOpError("requires an operand with tensor or memref type");
@ -516,12 +516,12 @@ bool DimOp::verify() const {
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
// Constant fold dim when the size along the index referred to is a constant.
auto *opType = getOperand()->getType();
auto opType = getOperand()->getType();
int indexSize = -1;
if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) {
indexSize = tensorType->getShape()[getIndex()];
} else if (auto *memrefType = dyn_cast<MemRefType>(opType)) {
indexSize = memrefType->getShape()[getIndex()];
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
indexSize = tensorType.getShape()[getIndex()];
} else if (auto memrefType = opType.dyn_cast<MemRefType>()) {
indexSize = memrefType.getShape()[getIndex()];
}
if (indexSize >= 0)
@ -544,9 +544,9 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
p->printOperands(getTagIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << *getSrcMemRef()->getType();
*p << ", " << *getDstMemRef()->getType();
*p << ", " << *getTagMemRef()->getType();
*p << " : " << getSrcMemRef()->getType();
*p << ", " << getDstMemRef()->getType();
*p << ", " << getTagMemRef()->getType();
}
// Parse DmaStartOp.
@ -566,8 +566,8 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
SmallVector<Type *, 3> types;
auto *indexType = parser->getBuilder().getIndexType();
SmallVector<Type, 3> types;
auto indexType = parser->getBuilder().getIndexType();
// Parse and resolve the following list of operands:
// *) source memref followed by its indices (in square brackets).
@ -601,12 +601,12 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
return true;
// Check that source/destination index list size matches associated rank.
if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() ||
dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank())
if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"memref rank not equal to indices count");
if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank())
if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count");
@ -632,7 +632,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
p->printOperands(getTagIndices());
*p << "], ";
p->printOperand(getNumElements());
*p << " : " << *getTagMemRef()->getType();
*p << " : " << getTagMemRef()->getType();
}
// Parse DmaWaitOp.
@ -642,8 +642,8 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
Type *type;
auto *indexType = parser->getBuilder().getIndexType();
Type type;
auto indexType = parser->getBuilder().getIndexType();
OpAsmParser::OperandType numElementsInfo;
// Parse tag memref, its indices, and dma size.
@ -657,7 +657,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
parser->resolveOperand(numElementsInfo, indexType, result->operands))
return true;
if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
if (tagIndexInfos.size() != type.cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count");
@ -678,10 +678,10 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningPatternList &results,
void ExtractElementOp::build(Builder *builder, OperationState *result,
SSAValue *aggregate,
ArrayRef<SSAValue *> indices) {
auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
result->addOperands(aggregate);
result->addOperands(indices);
result->types.push_back(aggregateType->getElementType());
result->types.push_back(aggregateType.getElementType());
}
void ExtractElementOp::print(OpAsmPrinter *p) const {
@ -689,13 +689,13 @@ void ExtractElementOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << *getAggregate()->getType();
*p << " : " << getAggregate()->getType();
}
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType aggregateInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
VectorOrTensorType *type;
VectorOrTensorType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(aggregateInfo) ||
@ -705,26 +705,26 @@ bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseColonType(type) ||
parser->resolveOperand(aggregateInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type->getElementType(), result->types);
parser->addTypeToList(type.getElementType(), result->types);
}
bool ExtractElementOp::verify() const {
if (getNumOperands() == 0)
return emitOpError("expected an aggregate to index into");
auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>();
if (!aggregateType)
return emitOpError("first operand must be a vector or tensor");
if (getType() != aggregateType->getElementType())
if (getType() != aggregateType.getElementType())
return emitOpError("result type must match element type of aggregate");
for (auto *idx : getIndices())
if (!idx->getType()->isIndex())
if (!idx->getType().isIndex())
return emitOpError("index to extract_element must have 'index' type");
// Verify the # indices match if we have a ranked type.
auto aggregateRank = aggregateType->getRank();
auto aggregateRank = aggregateType.getRank();
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
return emitOpError("incorrect number of indices for extract_element");
@ -737,10 +737,10 @@ bool ExtractElementOp::verify() const {
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
ArrayRef<SSAValue *> indices) {
auto *memrefType = cast<MemRefType>(memref->getType());
auto memrefType = memref->getType().cast<MemRefType>();
result->addOperands(memref);
result->addOperands(indices);
result->types.push_back(memrefType->getElementType());
result->types.push_back(memrefType.getElementType());
}
void LoadOp::print(OpAsmPrinter *p) const {
@ -748,13 +748,13 @@ void LoadOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << *getMemRefType();
*p << " : " << getMemRefType();
}
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType *type;
MemRefType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(memrefInfo) ||
@ -764,25 +764,25 @@ bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type->getElementType(), result->types);
parser->addTypeToList(type.getElementType(), result->types);
}
bool LoadOp::verify() const {
if (getNumOperands() == 0)
return emitOpError("expected a memref to load from");
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
if (!memRefType)
return emitOpError("first operand must be a memref");
if (getType() != memRefType->getElementType())
if (getType() != memRefType.getElementType())
return emitOpError("result type must match element type of memref");
if (memRefType->getRank() != getNumOperands() - 1)
if (memRefType.getRank() != getNumOperands() - 1)
return emitOpError("incorrect number of indices for load");
for (auto *idx : getIndices())
if (!idx->getType()->isIndex())
if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type");
// TODO: Verify we have the right number of indices.
@ -804,31 +804,31 @@ void LoadOp::getCanonicalizationPatterns(OwningPatternList &results,
//===----------------------------------------------------------------------===//
bool MemRefCastOp::verify() const {
auto *opType = dyn_cast<MemRefType>(getOperand()->getType());
auto *resType = dyn_cast<MemRefType>(getType());
auto opType = getOperand()->getType().dyn_cast<MemRefType>();
auto resType = getType().dyn_cast<MemRefType>();
if (!opType || !resType)
return emitOpError("requires input and result types to be memrefs");
if (opType == resType)
return emitOpError("requires the input and result type to be different");
if (opType->getElementType() != resType->getElementType())
if (opType.getElementType() != resType.getElementType())
return emitOpError(
"requires input and result element types to be the same");
if (opType->getAffineMaps() != resType->getAffineMaps())
if (opType.getAffineMaps() != resType.getAffineMaps())
return emitOpError("requires input and result mappings to be the same");
if (opType->getMemorySpace() != resType->getMemorySpace())
if (opType.getMemorySpace() != resType.getMemorySpace())
return emitOpError(
"requires input and result memory spaces to be the same");
// They must have the same rank, and any specified dimensions must match.
if (opType->getRank() != resType->getRank())
if (opType.getRank() != resType.getRank())
return emitOpError("requires input and result ranks to match");
for (unsigned i = 0, e = opType->getRank(); i != e; ++i) {
int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i);
for (unsigned i = 0, e = opType.getRank(); i != e; ++i) {
int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i);
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
return emitOpError("requires static dimensions to match");
}
@ -923,14 +923,14 @@ void StoreOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << *getMemRefType();
*p << " : " << getMemRefType();
}
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType *memrefType;
MemRefType memrefType;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
@ -939,7 +939,7 @@ bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(memrefType) ||
parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
result->operands) ||
parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands);
@ -950,19 +950,19 @@ bool StoreOp::verify() const {
return emitOpError("expected a value to store and a memref");
// Second operand is a memref type.
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
if (!memRefType)
return emitOpError("second operand must be a memref");
// First operand must have same type as memref element type.
if (getValueToStore()->getType() != memRefType->getElementType())
if (getValueToStore()->getType() != memRefType.getElementType())
return emitOpError("first operand must have same type memref element type");
if (getNumOperands() != 2 + memRefType->getRank())
if (getNumOperands() != 2 + memRefType.getRank())
return emitOpError("store index operand count not equal to memref rank");
for (auto *idx : getIndices())
if (!idx->getType()->isIndex())
if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type");
// TODO: Verify we have the right number of indices.
@ -1046,31 +1046,31 @@ void SubIOp::getCanonicalizationPatterns(OwningPatternList &results,
//===----------------------------------------------------------------------===//
bool TensorCastOp::verify() const {
auto *opType = dyn_cast<TensorType>(getOperand()->getType());
auto *resType = dyn_cast<TensorType>(getType());
auto opType = getOperand()->getType().dyn_cast<TensorType>();
auto resType = getType().dyn_cast<TensorType>();
if (!opType || !resType)
return emitOpError("requires input and result types to be tensors");
if (opType == resType)
return emitOpError("requires the input and result type to be different");
if (opType->getElementType() != resType->getElementType())
if (opType.getElementType() != resType.getElementType())
return emitOpError(
"requires input and result element types to be the same");
// If the source or destination are unranked, then the cast is valid.
auto *opRType = dyn_cast<RankedTensorType>(opType);
auto *resRType = dyn_cast<RankedTensorType>(resType);
auto opRType = opType.dyn_cast<RankedTensorType>();
auto resRType = resType.dyn_cast<RankedTensorType>();
if (!opRType || !resRType)
return false;
// If they are both ranked, they have to have the same rank, and any specified
// dimensions must match.
if (opRType->getRank() != resRType->getRank())
if (opRType.getRank() != resRType.getRank())
return emitOpError("requires input and result ranks to match");
for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) {
int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i);
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
return emitOpError("requires static dimensions to match");
}

View File

@ -31,7 +31,7 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
SmallVector<SSAValue *, 8> existingConstants;
// Operation statements that were folded and that need to be erased.
std::vector<OperationStmt *> opStmtsToErase;
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>;
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>;
bool foldOperation(Operation *op,
SmallVectorImpl<SSAValue *> &existingConstants,
@ -106,7 +106,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
auto &inst = *instIt++;
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
builder.setInsertionPoint(&inst);
return builder.create<ConstantOp>(inst.getLoc(), value, type);
};
@ -134,7 +134,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
// Override the walker's operation statement visit for constant folding.
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
MLFuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
};

View File

@ -77,23 +77,23 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
bInner.setInsertionPoint(forStmt, forStmt->begin());
// Doubles the shape with a leading dimension extent of 2.
auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
// Add the leading dimension in the shape for the double buffer.
ArrayRef<int> shape = oldMemRefType->getShape();
ArrayRef<int> shape = oldMemRefType.getShape();
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
shapeSizes.insert(shapeSizes.begin(), 2);
auto *newMemRefType =
bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
oldMemRefType->getMemorySpace());
auto newMemRefType =
bInner.getMemRefType(shapeSizes, oldMemRefType.getElementType(), {},
oldMemRefType.getMemorySpace());
return newMemRefType;
};
auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>());
// Create and place the alloc at the top level.
MLFuncBuilder topBuilder(forStmt->getFunction());
auto *newMemRef = cast<MLValue>(
auto newMemRef = cast<MLValue>(
topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
->getResult());

View File

@ -78,7 +78,7 @@ private:
/// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for.
DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
};
}; // end anonymous namespace

View File

@ -52,9 +52,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
MLValue *newMemRef,
ArrayRef<MLValue *> extraIndices,
AffineMap indexRemap) {
unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank;
if (indexRemap) {
assert(indexRemap.getNumInputs() == oldMemRefRank);
@ -64,8 +64,8 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
}
// Assert same elemental type.
assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
cast<MemRefType>(newMemRef->getType())->getElementType());
assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
newMemRef->getType().cast<MemRefType>().getElementType());
// Check if memref was used in a non-deferencing context.
for (const StmtOperand &use : oldMemRef->getUses()) {
@ -139,7 +139,7 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
opStmt->operand_end());
// Result types don't change. Both memref's are of the same elemental type.
SmallVector<Type *, 8> resultTypes;
SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults());
for (const auto *result : opStmt->getResults())
resultTypes.push_back(result->getType());

View File

@ -202,15 +202,15 @@ static bool analyzeProfitability(MLFunctionMatches matches,
/// sizes specified by vectorSize. The MemRef lives in the same memory space as
/// tmpl. The MemRef should be promoted to a closer memory address space in a
/// later pass.
static MemRefType *getVectorizedMemRefType(MemRefType *tmpl,
ArrayRef<int> vectorSizes) {
auto *elementType = tmpl->getElementType();
assert(!dyn_cast<VectorType>(elementType) &&
static MemRefType getVectorizedMemRefType(MemRefType tmpl,
ArrayRef<int> vectorSizes) {
auto elementType = tmpl.getElementType();
assert(!elementType.dyn_cast<VectorType>() &&
"Can't vectorize an already vector type");
assert(tmpl->getAffineMaps().empty() &&
assert(tmpl.getAffineMaps().empty() &&
"Unsupported non-implicit identity map");
return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {},
tmpl->getMemorySpace());
tmpl.getMemorySpace());
}
/// Creates an unaligned load with the following semantics:
@ -258,7 +258,7 @@ static void createUnalignedLoad(MLFuncBuilder *b, Location *loc,
operands.insert(operands.end(), dstMemRef);
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
using functional::map;
std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
return v->getType();
};
auto types = map(getType, operands);
@ -310,7 +310,7 @@ static void createUnalignedStore(MLFuncBuilder *b, Location *loc,
operands.insert(operands.end(), dstMemRef);
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
using functional::map;
std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
return v->getType();
};
auto types = map(getType, operands);
@ -348,8 +348,9 @@ static std::function<ToType *(T *)> unwrapPtr() {
template <typename LoadOrStoreOpPointer>
static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
ArrayRef<int> vectorSize) {
auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
auto memRefType =
memoryOp->getMemRef()->getType().template cast<MemRefType>();
auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
// Materialize a MemRef with 1 vector.
auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());