forked from OSchip/llvm-project
Implement value type abstraction for types.
This is done by changing Type to be a POD interface around an underlying pointer storage and adding in-class support for isa/dyn_cast/cast. PiperOrigin-RevId: 219372163
This commit is contained in:
parent
75376b8e33
commit
4c465a181d
|
@ -51,7 +51,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
|
|||
/// whether indices[dim] is independent of the value `input`.
|
||||
// For now we assume no layout map or identity layout map in the MemRef.
|
||||
// TODO(ntv): support more than identity layout map.
|
||||
bool isAccessInvariant(const MLValue &input, MemRefType *memRefType,
|
||||
bool isAccessInvariant(const MLValue &input, MemRefType memRefType,
|
||||
llvm::ArrayRef<MLValue *> indices, unsigned dim);
|
||||
|
||||
/// Checks whether all the LoadOp and StoreOp matched have access indexing
|
||||
|
|
|
@ -250,9 +250,9 @@ public:
|
|||
TypeAttr() = default;
|
||||
/* implicit */ TypeAttr(Attribute::ImplType *ptr);
|
||||
|
||||
static TypeAttr get(Type *type, MLIRContext *context);
|
||||
static TypeAttr get(Type type, MLIRContext *context);
|
||||
|
||||
Type *getValue() const;
|
||||
Type getValue() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::Type; }
|
||||
|
@ -277,7 +277,7 @@ public:
|
|||
|
||||
Function *getValue() const;
|
||||
|
||||
FunctionType *getType() const;
|
||||
FunctionType getType() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::Function; }
|
||||
|
@ -294,7 +294,7 @@ public:
|
|||
ElementsAttr() = default;
|
||||
/* implicit */ ElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
VectorOrTensorType *getType() const;
|
||||
VectorOrTensorType getType() const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(Kind kind) {
|
||||
|
@ -313,7 +313,7 @@ public:
|
|||
SplatElementsAttr() = default;
|
||||
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
static SplatElementsAttr get(VectorOrTensorType *type, Attribute elt);
|
||||
static SplatElementsAttr get(VectorOrTensorType type, Attribute elt);
|
||||
Attribute getValue() const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
|
@ -335,12 +335,12 @@ public:
|
|||
/// width specified by the element type (note all float type are 64 bits).
|
||||
/// When the value is retrieved, the bits are read from the storage and extend
|
||||
/// to 64 bits if necessary.
|
||||
static DenseElementsAttr get(VectorOrTensorType *type, ArrayRef<char> data);
|
||||
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data);
|
||||
|
||||
// TODO: Read the data from the attribute list and compress them
|
||||
// to a character array. Then call the above method to construct the
|
||||
// attribute.
|
||||
static DenseElementsAttr get(VectorOrTensorType *type,
|
||||
static DenseElementsAttr get(VectorOrTensorType type,
|
||||
ArrayRef<Attribute> values);
|
||||
|
||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||
|
@ -410,7 +410,7 @@ public:
|
|||
OpaqueElementsAttr() = default;
|
||||
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
static OpaqueElementsAttr get(VectorOrTensorType *type, StringRef bytes);
|
||||
static OpaqueElementsAttr get(VectorOrTensorType type, StringRef bytes);
|
||||
|
||||
StringRef getValue() const;
|
||||
|
||||
|
@ -440,7 +440,7 @@ public:
|
|||
SparseElementsAttr() = default;
|
||||
/* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
static SparseElementsAttr get(VectorOrTensorType *type,
|
||||
static SparseElementsAttr get(VectorOrTensorType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values);
|
||||
|
||||
|
|
|
@ -64,10 +64,10 @@ public:
|
|||
bool args_empty() const { return arguments.empty(); }
|
||||
|
||||
/// Add one value to the operand list.
|
||||
BBArgument *addArgument(Type *type);
|
||||
BBArgument *addArgument(Type type);
|
||||
|
||||
/// Add one argument to the argument list for each type specified in the list.
|
||||
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type *> types);
|
||||
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
|
||||
|
||||
unsigned getNumArguments() const { return arguments.size(); }
|
||||
BBArgument *getArgument(unsigned i) { return arguments[i]; }
|
||||
|
|
|
@ -68,29 +68,28 @@ public:
|
|||
unsigned column);
|
||||
|
||||
// Types.
|
||||
FloatType *getBF16Type();
|
||||
FloatType *getF16Type();
|
||||
FloatType *getF32Type();
|
||||
FloatType *getF64Type();
|
||||
FloatType getBF16Type();
|
||||
FloatType getF16Type();
|
||||
FloatType getF32Type();
|
||||
FloatType getF64Type();
|
||||
|
||||
OtherType *getIndexType();
|
||||
OtherType *getTFControlType();
|
||||
OtherType *getTFStringType();
|
||||
OtherType *getTFResourceType();
|
||||
OtherType *getTFVariantType();
|
||||
OtherType *getTFComplex64Type();
|
||||
OtherType *getTFComplex128Type();
|
||||
OtherType *getTFF32REFType();
|
||||
OtherType getIndexType();
|
||||
OtherType getTFControlType();
|
||||
OtherType getTFStringType();
|
||||
OtherType getTFResourceType();
|
||||
OtherType getTFVariantType();
|
||||
OtherType getTFComplex64Type();
|
||||
OtherType getTFComplex128Type();
|
||||
OtherType getTFF32REFType();
|
||||
|
||||
IntegerType *getIntegerType(unsigned width);
|
||||
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
|
||||
ArrayRef<Type *> results);
|
||||
MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition = {},
|
||||
unsigned memorySpace = 0);
|
||||
VectorType *getVectorType(ArrayRef<int> shape, Type *elementType);
|
||||
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
|
||||
UnrankedTensorType *getTensorType(Type *elementType);
|
||||
IntegerType getIntegerType(unsigned width);
|
||||
FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
|
||||
MemRefType getMemRefType(ArrayRef<int> shape, Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition = {},
|
||||
unsigned memorySpace = 0);
|
||||
VectorType getVectorType(ArrayRef<int> shape, Type elementType);
|
||||
RankedTensorType getTensorType(ArrayRef<int> shape, Type elementType);
|
||||
UnrankedTensorType getTensorType(Type elementType);
|
||||
|
||||
// Attributes.
|
||||
|
||||
|
@ -102,15 +101,15 @@ public:
|
|||
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
|
||||
AffineMapAttr getAffineMapAttr(AffineMap map);
|
||||
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
|
||||
TypeAttr getTypeAttr(Type *type);
|
||||
TypeAttr getTypeAttr(Type type);
|
||||
FunctionAttr getFunctionAttr(const Function *value);
|
||||
ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt);
|
||||
ElementsAttr getDenseElementsAttr(VectorOrTensorType *type,
|
||||
ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt);
|
||||
ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
|
||||
ArrayRef<char> data);
|
||||
ElementsAttr getSparseElementsAttr(VectorOrTensorType *type,
|
||||
ElementsAttr getSparseElementsAttr(VectorOrTensorType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values);
|
||||
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes);
|
||||
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes);
|
||||
|
||||
// Affine expressions and affine maps.
|
||||
AffineExpr getAffineDimExpr(unsigned position);
|
||||
|
@ -366,7 +365,7 @@ public:
|
|||
/// Creates an operation given the fields.
|
||||
OperationStmt *createOperation(Location *location, OperationName name,
|
||||
ArrayRef<MLValue *> operands,
|
||||
ArrayRef<Type *> types,
|
||||
ArrayRef<Type> types,
|
||||
ArrayRef<NamedAttribute> attrs);
|
||||
|
||||
/// Create operation of specific op type at the current insertion point.
|
||||
|
|
|
@ -96,7 +96,7 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
|
|||
public:
|
||||
/// Builds a constant op with the specified attribute value and result type.
|
||||
static void build(Builder *builder, OperationState *result, Attribute value,
|
||||
Type *type);
|
||||
Type type);
|
||||
|
||||
Attribute getValue() const { return getAttr("value"); }
|
||||
|
||||
|
@ -123,7 +123,7 @@ class ConstantFloatOp : public ConstantOp {
|
|||
public:
|
||||
/// Builds a constant float op producing a float of the specified type.
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
const APFloat &value, FloatType *type);
|
||||
const APFloat &value, FloatType type);
|
||||
|
||||
APFloat getValue() const {
|
||||
return getAttrOfType<FloatAttr>("value").getValue();
|
||||
|
@ -150,7 +150,7 @@ public:
|
|||
/// Build a constant int op producing an integer with the specified type,
|
||||
/// which must be an integer type.
|
||||
static void build(Builder *builder, OperationState *result, int64_t value,
|
||||
Type *type);
|
||||
Type type);
|
||||
|
||||
int64_t getValue() const {
|
||||
return getAttrOfType<IntegerAttr>("value").getValue();
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace mlir {
|
|||
// blocks, each of which includes instructions.
|
||||
class CFGFunction : public Function {
|
||||
public:
|
||||
CFGFunction(Location *location, StringRef name, FunctionType *type,
|
||||
CFGFunction(Location *location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
|
||||
~CFGFunction();
|
||||
|
|
|
@ -66,7 +66,7 @@ public:
|
|||
}
|
||||
|
||||
protected:
|
||||
CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
|
||||
CFGValue(CFGValueKind kind, Type type) : SSAValueImpl(kind, type) {}
|
||||
};
|
||||
|
||||
/// Basic block arguments are CFG Values.
|
||||
|
@ -87,7 +87,7 @@ public:
|
|||
|
||||
private:
|
||||
friend class BasicBlock; // For access to private constructor.
|
||||
BBArgument(Type *type, BasicBlock *owner)
|
||||
BBArgument(Type type, BasicBlock *owner)
|
||||
: CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
|
||||
|
||||
/// The owner of this operand.
|
||||
|
@ -99,7 +99,7 @@ private:
|
|||
/// Instruction results are CFG Values.
|
||||
class InstResult : public CFGValue {
|
||||
public:
|
||||
InstResult(Type *type, OperationInst *owner)
|
||||
InstResult(Type type, OperationInst *owner)
|
||||
: CFGValue(CFGValueKind::InstResult, type), owner(owner) {}
|
||||
|
||||
static bool classof(const SSAValue *value) {
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ilist.h"
|
||||
|
||||
|
@ -55,7 +56,7 @@ public:
|
|||
Identifier getName() const { return nameAndKind.getPointer(); }
|
||||
|
||||
/// Return the type of this function.
|
||||
FunctionType *getType() const { return type; }
|
||||
FunctionType getType() const { return type; }
|
||||
|
||||
/// Returns all of the attributes on this function.
|
||||
ArrayRef<NamedAttribute> getAttrs() const;
|
||||
|
@ -93,7 +94,7 @@ public:
|
|||
void emitNote(const Twine &message) const;
|
||||
|
||||
protected:
|
||||
Function(Kind kind, Location *location, StringRef name, FunctionType *type,
|
||||
Function(Kind kind, Location *location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
~Function();
|
||||
|
||||
|
@ -108,7 +109,7 @@ private:
|
|||
Location *location;
|
||||
|
||||
/// The type of the function.
|
||||
FunctionType *const type;
|
||||
FunctionType type;
|
||||
|
||||
/// This holds general named attributes for the function.
|
||||
AttributeListStorage *attrs;
|
||||
|
@ -121,7 +122,7 @@ private:
|
|||
/// defined in some other module.
|
||||
class ExtFunction : public Function {
|
||||
public:
|
||||
ExtFunction(Location *location, StringRef name, FunctionType *type,
|
||||
ExtFunction(Location *location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
|
|
|
@ -202,7 +202,7 @@ public:
|
|||
/// Create a new OperationInst with the specified fields.
|
||||
static OperationInst *create(Location *location, OperationName name,
|
||||
ArrayRef<CFGValue *> operands,
|
||||
ArrayRef<Type *> resultTypes,
|
||||
ArrayRef<Type> resultTypes,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
MLIRContext *context);
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ class MLFunction final
|
|||
public:
|
||||
/// Creates a new MLFunction with the specific type.
|
||||
static MLFunction *create(Location *location, StringRef name,
|
||||
FunctionType *type,
|
||||
FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
|
||||
/// Destroys this statement and its subclass data.
|
||||
|
@ -52,7 +52,7 @@ public:
|
|||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Returns number of arguments.
|
||||
unsigned getNumArguments() const { return getType()->getInputs().size(); }
|
||||
unsigned getNumArguments() const { return getType().getInputs().size(); }
|
||||
|
||||
/// Gets argument.
|
||||
MLFuncArgument *getArgument(unsigned idx) {
|
||||
|
@ -103,13 +103,13 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
MLFunction(Location *location, StringRef name, FunctionType *type,
|
||||
MLFunction(Location *location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
|
||||
// This stuff is used by the TrailingObjects template.
|
||||
friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;
|
||||
size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const {
|
||||
return getType()->getInputs().size();
|
||||
return getType().getInputs().size();
|
||||
}
|
||||
|
||||
// Internal functions to get argument list used by getArgument() methods.
|
||||
|
|
|
@ -73,7 +73,7 @@ public:
|
|||
}
|
||||
|
||||
protected:
|
||||
MLValue(MLValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
|
||||
MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {}
|
||||
};
|
||||
|
||||
/// This is the value defined by an argument of an ML function.
|
||||
|
@ -93,7 +93,7 @@ public:
|
|||
|
||||
private:
|
||||
friend class MLFunction; // For access to private constructor.
|
||||
MLFuncArgument(Type *type, MLFunction *owner)
|
||||
MLFuncArgument(Type type, MLFunction *owner)
|
||||
: MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {}
|
||||
|
||||
/// The owner of this operand.
|
||||
|
@ -105,7 +105,7 @@ private:
|
|||
/// This is a value defined by a result of an operation instruction.
|
||||
class StmtResult : public MLValue {
|
||||
public:
|
||||
StmtResult(Type *type, OperationStmt *owner)
|
||||
StmtResult(Type type, OperationStmt *owner)
|
||||
: MLValue(MLValueKind::StmtResult, type), owner(owner) {}
|
||||
|
||||
static bool classof(const SSAValue *value) {
|
||||
|
|
|
@ -71,13 +71,13 @@ struct constant_int_op_binder {
|
|||
|
||||
bool match(Operation *op) {
|
||||
if (auto constOp = op->dyn_cast<ConstantOp>()) {
|
||||
auto *type = constOp->getResult()->getType();
|
||||
auto type = constOp->getResult()->getType();
|
||||
auto attr = constOp->getAttr("value");
|
||||
|
||||
if (isa<IntegerType>(type)) {
|
||||
if (type.isa<IntegerType>()) {
|
||||
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
|
||||
}
|
||||
if (isa<VectorOrTensorType>(type)) {
|
||||
if (type.isa<VectorOrTensorType>()) {
|
||||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||
return attr_value_binder<IntegerAttr>(bind_value)
|
||||
.match(splatAttr.getValue());
|
||||
|
|
|
@ -493,7 +493,7 @@ public:
|
|||
return this->getOperation()->getResult(0);
|
||||
}
|
||||
|
||||
Type *getType() const { return getResult()->getType(); }
|
||||
Type getType() const { return getResult()->getType(); }
|
||||
|
||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||
/// the IR that uses 'this' to use the other value instead. When this returns
|
||||
|
@ -539,7 +539,7 @@ public:
|
|||
return this->getOperation()->getResult(i);
|
||||
}
|
||||
|
||||
Type *getType(unsigned i) const { return getResult(i)->getType(); }
|
||||
Type getType(unsigned i) const { return getResult(i)->getType(); }
|
||||
|
||||
static bool verifyTrait(const Operation *op) {
|
||||
return impl::verifyNResults(op, N);
|
||||
|
@ -565,7 +565,7 @@ public:
|
|||
return this->getOperation()->getResult(i);
|
||||
}
|
||||
|
||||
Type *getType(unsigned i) const { return getResult(i)->getType(); }
|
||||
Type getType(unsigned i) const { return getResult(i)->getType(); }
|
||||
|
||||
static bool verifyTrait(const Operation *op) {
|
||||
return impl::verifyAtLeastNResults(op, N);
|
||||
|
@ -803,7 +803,7 @@ protected:
|
|||
// which avoids them being template instantiated/duplicated.
|
||||
namespace impl {
|
||||
void buildCastOp(Builder *builder, OperationState *result, SSAValue *source,
|
||||
Type *destType);
|
||||
Type destType);
|
||||
bool parseCastOp(OpAsmParser *parser, OperationState *result);
|
||||
void printCastOp(const Operation *op, OpAsmPrinter *p);
|
||||
} // namespace impl
|
||||
|
@ -819,7 +819,7 @@ class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
|
|||
OpTrait::HasNoSideEffect, Traits...> {
|
||||
public:
|
||||
static void build(Builder *builder, OperationState *result, SSAValue *source,
|
||||
Type *destType) {
|
||||
Type destType) {
|
||||
impl::buildCastOp(builder, result, source, destType);
|
||||
}
|
||||
static bool parse(OpAsmParser *parser, OperationState *result) {
|
||||
|
|
|
@ -67,7 +67,7 @@ public:
|
|||
printOperand(*it);
|
||||
}
|
||||
}
|
||||
virtual void printType(const Type *type) = 0;
|
||||
virtual void printType(Type type) = 0;
|
||||
virtual void printFunctionReference(const Function *func) = 0;
|
||||
virtual void printAttribute(Attribute attr) = 0;
|
||||
virtual void printAffineMap(AffineMap map) = 0;
|
||||
|
@ -95,8 +95,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const SSAValue &value) {
|
|||
return p;
|
||||
}
|
||||
|
||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Type &type) {
|
||||
p.printType(&type);
|
||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
|
||||
p.printType(type);
|
||||
return p;
|
||||
}
|
||||
|
||||
|
@ -163,20 +163,20 @@ public:
|
|||
virtual bool parseComma() = 0;
|
||||
|
||||
/// Parse a colon followed by a type.
|
||||
virtual bool parseColonType(Type *&result) = 0;
|
||||
virtual bool parseColonType(Type &result) = 0;
|
||||
|
||||
/// Parse a type of a specific kind, e.g. a FunctionType.
|
||||
template <typename TypeType> bool parseColonType(TypeType *&result) {
|
||||
template <typename TypeType> bool parseColonType(TypeType &result) {
|
||||
llvm::SMLoc loc;
|
||||
getCurrentLocation(&loc);
|
||||
|
||||
// Parse any kind of type.
|
||||
Type *type;
|
||||
Type type;
|
||||
if (parseColonType(type))
|
||||
return true;
|
||||
|
||||
// Check for the right kind of attribute.
|
||||
result = dyn_cast<TypeType>(type);
|
||||
result = type.dyn_cast<TypeType>();
|
||||
if (!result) {
|
||||
emitError(loc, "invalid kind of type specified");
|
||||
return true;
|
||||
|
@ -186,15 +186,15 @@ public:
|
|||
}
|
||||
|
||||
/// Parse a colon followed by a type list, which must have at least one type.
|
||||
virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result) = 0;
|
||||
virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
|
||||
|
||||
/// Parse a keyword followed by a type.
|
||||
virtual bool parseKeywordType(const char *keyword, Type *&result) = 0;
|
||||
virtual bool parseKeywordType(const char *keyword, Type &result) = 0;
|
||||
|
||||
/// Add the specified type to the end of the specified type list and return
|
||||
/// false. This is a helper designed to allow parse methods to be simple and
|
||||
/// chain through || operators.
|
||||
bool addTypeToList(Type *type, SmallVectorImpl<Type *> &result) {
|
||||
bool addTypeToList(Type type, SmallVectorImpl<Type> &result) {
|
||||
result.push_back(type);
|
||||
return false;
|
||||
}
|
||||
|
@ -202,7 +202,7 @@ public:
|
|||
/// Add the specified types to the end of the specified type list and return
|
||||
/// false. This is a helper designed to allow parse methods to be simple and
|
||||
/// chain through || operators.
|
||||
bool addTypesToList(ArrayRef<Type *> types, SmallVectorImpl<Type *> &result) {
|
||||
bool addTypesToList(ArrayRef<Type> types, SmallVectorImpl<Type> &result) {
|
||||
result.append(types.begin(), types.end());
|
||||
return false;
|
||||
}
|
||||
|
@ -288,13 +288,13 @@ public:
|
|||
|
||||
/// Resolve an operand to an SSA value, emitting an error and returning true
|
||||
/// on failure.
|
||||
virtual bool resolveOperand(const OperandType &operand, Type *type,
|
||||
virtual bool resolveOperand(const OperandType &operand, Type type,
|
||||
SmallVectorImpl<SSAValue *> &result) = 0;
|
||||
|
||||
/// Resolve a list of operands to SSA values, emitting an error and returning
|
||||
/// true on failure, or appending the results to the list on success.
|
||||
/// This method should be used when all operands have the same type.
|
||||
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type *type,
|
||||
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type,
|
||||
SmallVectorImpl<SSAValue *> &result) {
|
||||
for (auto elt : operands)
|
||||
if (resolveOperand(elt, type, result))
|
||||
|
@ -306,7 +306,7 @@ public:
|
|||
/// emitting an error and returning true on failure, or appending the results
|
||||
/// to the list on success.
|
||||
virtual bool resolveOperands(ArrayRef<OperandType> operands,
|
||||
ArrayRef<Type *> types, llvm::SMLoc loc,
|
||||
ArrayRef<Type> types, llvm::SMLoc loc,
|
||||
SmallVectorImpl<SSAValue *> &result) {
|
||||
if (operands.size() != types.size())
|
||||
return emitError(loc, Twine(operands.size()) +
|
||||
|
@ -321,7 +321,7 @@ public:
|
|||
}
|
||||
|
||||
/// Resolve a parse function name and a type into a function reference.
|
||||
virtual bool resolveFunctionName(StringRef name, FunctionType *type,
|
||||
virtual bool resolveFunctionName(StringRef name, FunctionType type,
|
||||
llvm::SMLoc loc, Function *&result) = 0;
|
||||
|
||||
/// Emit a diagnostic at the specified location and return true.
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "llvm/ADT/PointerUnion.h"
|
||||
#include <memory>
|
||||
|
||||
|
@ -191,7 +192,7 @@ struct OperationState {
|
|||
OperationName name;
|
||||
SmallVector<SSAValue *, 4> operands;
|
||||
/// Types of the results of this operation.
|
||||
SmallVector<Type *, 4> types;
|
||||
SmallVector<Type, 4> types;
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
|
||||
public:
|
||||
|
@ -202,7 +203,7 @@ public:
|
|||
: context(context), location(location), name(name) {}
|
||||
|
||||
OperationState(MLIRContext *context, Location *location, StringRef name,
|
||||
ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
|
||||
ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
|
||||
ArrayRef<NamedAttribute> attributes = {})
|
||||
: context(context), location(location), name(name, context),
|
||||
operands(operands.begin(), operands.end()),
|
||||
|
@ -213,7 +214,7 @@ public:
|
|||
operands.append(newOperands.begin(), newOperands.end());
|
||||
}
|
||||
|
||||
void addTypes(ArrayRef<Type *> newTypes) {
|
||||
void addTypes(ArrayRef<Type> newTypes) {
|
||||
types.append(newTypes.begin(), newTypes.end());
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/UseDefLists.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/PointerIntPair.h"
|
||||
|
||||
namespace mlir {
|
||||
class Function;
|
||||
|
@ -51,7 +50,7 @@ public:
|
|||
|
||||
SSAValueKind getKind() const { return typeAndKind.getInt(); }
|
||||
|
||||
Type *getType() const { return typeAndKind.getPointer(); }
|
||||
Type getType() const { return typeAndKind.getPointer(); }
|
||||
|
||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||
/// the IR that uses 'this' to use the other value instead. When this returns
|
||||
|
@ -93,9 +92,10 @@ public:
|
|||
void dump() const;
|
||||
|
||||
protected:
|
||||
SSAValue(SSAValueKind kind, Type *type) : typeAndKind(type, kind) {}
|
||||
SSAValue(SSAValueKind kind, Type type) : typeAndKind(type, kind) {}
|
||||
|
||||
private:
|
||||
const llvm::PointerIntPair<Type *, 3, SSAValueKind> typeAndKind;
|
||||
const llvm::PointerIntPair<Type, 3, SSAValueKind> typeAndKind;
|
||||
};
|
||||
|
||||
inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) {
|
||||
|
@ -127,7 +127,7 @@ public:
|
|||
inline use_range getUses() const;
|
||||
|
||||
protected:
|
||||
SSAValueImpl(KindTy kind, Type *type) : SSAValue((SSAValueKind)kind, type) {}
|
||||
SSAValueImpl(KindTy kind, Type type) : SSAValue((SSAValueKind)kind, type) {}
|
||||
};
|
||||
|
||||
// Utility functions for iterating through SSAValue uses.
|
||||
|
|
|
@ -44,7 +44,7 @@ public:
|
|||
/// Create a new OperationStmt with the specific fields.
|
||||
static OperationStmt *create(Location *location, OperationName name,
|
||||
ArrayRef<MLValue *> operands,
|
||||
ArrayRef<Type *> resultTypes,
|
||||
ArrayRef<Type> resultTypes,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
MLIRContext *context);
|
||||
|
||||
|
@ -329,7 +329,7 @@ public:
|
|||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Return the context this operation is associated with.
|
||||
MLIRContext *getContext() const { return getType()->getContext(); }
|
||||
MLIRContext *getContext() const { return getType().getContext(); }
|
||||
|
||||
using Statement::dump;
|
||||
using Statement::print;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
|
||||
namespace mlir {
|
||||
class AffineMap;
|
||||
|
@ -28,6 +29,22 @@ class IntegerType;
|
|||
class FloatType;
|
||||
class OtherType;
|
||||
|
||||
namespace detail {
|
||||
|
||||
class TypeStorage;
|
||||
class IntegerTypeStorage;
|
||||
class FloatTypeStorage;
|
||||
struct OtherTypeStorage;
|
||||
struct FunctionTypeStorage;
|
||||
struct VectorOrTensorTypeStorage;
|
||||
struct VectorTypeStorage;
|
||||
struct TensorTypeStorage;
|
||||
struct RankedTensorTypeStorage;
|
||||
struct UnrankedTensorTypeStorage;
|
||||
struct MemRefTypeStorage;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
|
||||
/// MLIRContext. As such, they are passed around by raw non-const pointer.
|
||||
///
|
||||
|
@ -68,11 +85,34 @@ public:
|
|||
MemRef,
|
||||
};
|
||||
|
||||
using ImplType = detail::TypeStorage;
|
||||
|
||||
Type() : type(nullptr) {}
|
||||
/* implicit */ Type(const ImplType *type)
|
||||
: type(const_cast<ImplType *>(type)) {}
|
||||
|
||||
Type(const Type &other) : type(other.type) {}
|
||||
Type &operator=(Type other) {
|
||||
type = other.type;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool operator==(Type other) const { return type == other.type; }
|
||||
bool operator!=(Type other) const { return !(*this == other); }
|
||||
explicit operator bool() const { return type; }
|
||||
|
||||
bool operator!() const { return type == nullptr; }
|
||||
|
||||
template <typename U> bool isa() const;
|
||||
template <typename U> U dyn_cast() const;
|
||||
template <typename U> U dyn_cast_or_null() const;
|
||||
template <typename U> U cast() const;
|
||||
|
||||
/// Return the classification for this type.
|
||||
Kind getKind() const { return kind; }
|
||||
Kind getKind() const;
|
||||
|
||||
/// Return the LLVMContext in which this type was uniqued.
|
||||
MLIRContext *getContext() const { return context; }
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
// Convenience predicates. This is only for 'other' and floating point types,
|
||||
// derived types should use isa/dyn_cast.
|
||||
|
@ -97,56 +137,42 @@ public:
|
|||
unsigned getBitWidth() const;
|
||||
|
||||
// Convenience factories.
|
||||
static IntegerType *getInteger(unsigned width, MLIRContext *ctx);
|
||||
static FloatType *getBF16(MLIRContext *ctx);
|
||||
static FloatType *getF16(MLIRContext *ctx);
|
||||
static FloatType *getF32(MLIRContext *ctx);
|
||||
static FloatType *getF64(MLIRContext *ctx);
|
||||
static OtherType *getIndex(MLIRContext *ctx);
|
||||
static OtherType *getTFControl(MLIRContext *ctx);
|
||||
static OtherType *getTFString(MLIRContext *ctx);
|
||||
static OtherType *getTFResource(MLIRContext *ctx);
|
||||
static OtherType *getTFVariant(MLIRContext *ctx);
|
||||
static OtherType *getTFComplex64(MLIRContext *ctx);
|
||||
static OtherType *getTFComplex128(MLIRContext *ctx);
|
||||
static OtherType *getTFF32REF(MLIRContext *ctx);
|
||||
static IntegerType getInteger(unsigned width, MLIRContext *ctx);
|
||||
static FloatType getBF16(MLIRContext *ctx);
|
||||
static FloatType getF16(MLIRContext *ctx);
|
||||
static FloatType getF32(MLIRContext *ctx);
|
||||
static FloatType getF64(MLIRContext *ctx);
|
||||
static OtherType getIndex(MLIRContext *ctx);
|
||||
static OtherType getTFControl(MLIRContext *ctx);
|
||||
static OtherType getTFString(MLIRContext *ctx);
|
||||
static OtherType getTFResource(MLIRContext *ctx);
|
||||
static OtherType getTFVariant(MLIRContext *ctx);
|
||||
static OtherType getTFComplex64(MLIRContext *ctx);
|
||||
static OtherType getTFComplex128(MLIRContext *ctx);
|
||||
static OtherType getTFF32REF(MLIRContext *ctx);
|
||||
|
||||
/// Print the current type.
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
friend ::llvm::hash_code hash_value(Type arg);
|
||||
|
||||
unsigned getSubclassData() const;
|
||||
void setSubclassData(unsigned val);
|
||||
|
||||
/// Methods for supporting PointerLikeTypeTraits.
|
||||
const void *getAsOpaquePointer() const {
|
||||
return static_cast<const void *>(type);
|
||||
}
|
||||
static Type getFromOpaquePointer(const void *pointer) {
|
||||
return Type((ImplType *)(pointer));
|
||||
}
|
||||
|
||||
protected:
|
||||
explicit Type(Kind kind, MLIRContext *context)
|
||||
: context(context), kind(kind), subclassData(0) {}
|
||||
explicit Type(Kind kind, MLIRContext *context, unsigned subClassData)
|
||||
: Type(kind, context) {
|
||||
setSubclassData(subClassData);
|
||||
}
|
||||
|
||||
~Type() {}
|
||||
|
||||
unsigned getSubclassData() const { return subclassData; }
|
||||
|
||||
void setSubclassData(unsigned val) {
|
||||
subclassData = val;
|
||||
// Ensure we don't have any accidental truncation.
|
||||
assert(getSubclassData() == val && "Subclass data too large for field");
|
||||
}
|
||||
|
||||
private:
|
||||
Type(const Type&) = delete;
|
||||
void operator=(const Type&) = delete;
|
||||
/// This refers to the MLIRContext in which this type was uniqued.
|
||||
MLIRContext *const context;
|
||||
|
||||
/// Classification of the subclass, used for type checking.
|
||||
Kind kind : 8;
|
||||
|
||||
// Space for subclasses to store data.
|
||||
unsigned subclassData : 24;
|
||||
ImplType *type;
|
||||
};
|
||||
|
||||
inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
|
||||
inline raw_ostream &operator<<(raw_ostream &os, Type type) {
|
||||
type.print(os);
|
||||
return os;
|
||||
}
|
||||
|
@ -154,148 +180,138 @@ inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
|
|||
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
|
||||
class IntegerType : public Type {
|
||||
public:
|
||||
static IntegerType *get(unsigned width, MLIRContext *context);
|
||||
using ImplType = detail::IntegerTypeStorage;
|
||||
IntegerType() = default;
|
||||
/* implicit */ IntegerType(Type::ImplType *ptr);
|
||||
|
||||
static IntegerType get(unsigned width, MLIRContext *context);
|
||||
|
||||
/// Return the bitwidth of this integer type.
|
||||
unsigned getWidth() const {
|
||||
return width;
|
||||
}
|
||||
unsigned getWidth() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::Integer;
|
||||
}
|
||||
static bool kindof(Kind kind) { return kind == Kind::Integer; }
|
||||
|
||||
/// Integer representation maximal bitwidth.
|
||||
static constexpr unsigned kMaxWidth = 4096;
|
||||
private:
|
||||
unsigned width;
|
||||
IntegerType(unsigned width, MLIRContext *context);
|
||||
~IntegerType() = delete;
|
||||
};
|
||||
|
||||
inline IntegerType *Type::getInteger(unsigned width, MLIRContext *ctx) {
|
||||
inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
|
||||
return IntegerType::get(width, ctx);
|
||||
}
|
||||
|
||||
/// Return true if this is an integer type with the specified width.
|
||||
inline bool Type::isInteger(unsigned width) const {
|
||||
if (auto *intTy = dyn_cast<IntegerType>(this))
|
||||
return intTy->getWidth() == width;
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
return intTy.getWidth() == width;
|
||||
return false;
|
||||
}
|
||||
|
||||
class FloatType : public Type {
|
||||
public:
|
||||
using ImplType = detail::FloatTypeStorage;
|
||||
FloatType() = default;
|
||||
/* implicit */ FloatType(Type::ImplType *ptr);
|
||||
|
||||
static FloatType get(Kind kind, MLIRContext *context);
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() >= Kind::FIRST_FLOATING_POINT_TYPE &&
|
||||
type->getKind() <= Kind::LAST_FLOATING_POINT_TYPE;
|
||||
static bool kindof(Kind kind) {
|
||||
return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
|
||||
kind <= Kind::LAST_FLOATING_POINT_TYPE;
|
||||
}
|
||||
|
||||
static FloatType *get(Kind kind, MLIRContext *context);
|
||||
|
||||
private:
|
||||
FloatType(Kind kind, MLIRContext *context);
|
||||
~FloatType() = delete;
|
||||
};
|
||||
|
||||
inline FloatType *Type::getBF16(MLIRContext *ctx) {
|
||||
inline FloatType Type::getBF16(MLIRContext *ctx) {
|
||||
return FloatType::get(Kind::BF16, ctx);
|
||||
}
|
||||
inline FloatType *Type::getF16(MLIRContext *ctx) {
|
||||
inline FloatType Type::getF16(MLIRContext *ctx) {
|
||||
return FloatType::get(Kind::F16, ctx);
|
||||
}
|
||||
inline FloatType *Type::getF32(MLIRContext *ctx) {
|
||||
inline FloatType Type::getF32(MLIRContext *ctx) {
|
||||
return FloatType::get(Kind::F32, ctx);
|
||||
}
|
||||
inline FloatType *Type::getF64(MLIRContext *ctx) {
|
||||
inline FloatType Type::getF64(MLIRContext *ctx) {
|
||||
return FloatType::get(Kind::F64, ctx);
|
||||
}
|
||||
|
||||
/// This is a type for the random collection of special base types.
|
||||
class OtherType : public Type {
|
||||
public:
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() >= Kind::FIRST_OTHER_TYPE &&
|
||||
type->getKind() <= Kind::LAST_OTHER_TYPE;
|
||||
}
|
||||
static OtherType *get(Kind kind, MLIRContext *context);
|
||||
using ImplType = detail::OtherTypeStorage;
|
||||
OtherType() = default;
|
||||
/* implicit */ OtherType(Type::ImplType *ptr);
|
||||
|
||||
private:
|
||||
OtherType(Kind kind, MLIRContext *context);
|
||||
~OtherType() = delete;
|
||||
static OtherType get(Kind kind, MLIRContext *context);
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) {
|
||||
return kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE;
|
||||
}
|
||||
};
|
||||
|
||||
inline OtherType *Type::getIndex(MLIRContext *ctx) {
|
||||
inline OtherType Type::getIndex(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::Index, ctx);
|
||||
}
|
||||
inline OtherType *Type::getTFControl(MLIRContext *ctx) {
|
||||
inline OtherType Type::getTFControl(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFControl, ctx);
|
||||
}
|
||||
inline OtherType *Type::getTFResource(MLIRContext *ctx) {
|
||||
inline OtherType Type::getTFResource(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFResource, ctx);
|
||||
}
|
||||
inline OtherType *Type::getTFString(MLIRContext *ctx) {
|
||||
inline OtherType Type::getTFString(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFString, ctx);
|
||||
}
|
||||
inline OtherType *Type::getTFVariant(MLIRContext *ctx) {
|
||||
inline OtherType Type::getTFVariant(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFVariant, ctx);
|
||||
}
|
||||
inline OtherType *Type::getTFComplex64(MLIRContext *ctx) {
|
||||
inline OtherType Type::getTFComplex64(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFComplex64, ctx);
|
||||
}
|
||||
inline OtherType *Type::getTFComplex128(MLIRContext *ctx) {
|
||||
inline OtherType Type::getTFComplex128(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFComplex128, ctx);
|
||||
}
|
||||
inline OtherType *Type::getTFF32REF(MLIRContext *ctx) {
|
||||
inline OtherType Type::getTFF32REF(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFF32REF, ctx);
|
||||
}
|
||||
|
||||
/// Function types map from a list of inputs to a list of results.
|
||||
class FunctionType : public Type {
|
||||
public:
|
||||
static FunctionType *get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
|
||||
MLIRContext *context);
|
||||
using ImplType = detail::FunctionTypeStorage;
|
||||
FunctionType() = default;
|
||||
/* implicit */ FunctionType(Type::ImplType *ptr);
|
||||
|
||||
static FunctionType get(ArrayRef<Type> inputs, ArrayRef<Type> results,
|
||||
MLIRContext *context);
|
||||
|
||||
// Input types.
|
||||
unsigned getNumInputs() const { return getSubclassData(); }
|
||||
|
||||
Type *getInput(unsigned i) const { return getInputs()[i]; }
|
||||
Type getInput(unsigned i) const { return getInputs()[i]; }
|
||||
|
||||
ArrayRef<Type*> getInputs() const {
|
||||
return ArrayRef<Type *>(inputsAndResults, getNumInputs());
|
||||
}
|
||||
ArrayRef<Type> getInputs() const;
|
||||
|
||||
// Result types.
|
||||
unsigned getNumResults() const { return numResults; }
|
||||
unsigned getNumResults() const;
|
||||
|
||||
Type *getResult(unsigned i) const { return getResults()[i]; }
|
||||
Type getResult(unsigned i) const { return getResults()[i]; }
|
||||
|
||||
ArrayRef<Type*> getResults() const {
|
||||
return ArrayRef<Type *>(inputsAndResults + getSubclassData(), numResults);
|
||||
}
|
||||
ArrayRef<Type> getResults() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::Function;
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned numResults;
|
||||
Type *const *inputsAndResults;
|
||||
|
||||
FunctionType(Type *const *inputsAndResults, unsigned numInputs,
|
||||
unsigned numResults, MLIRContext *context);
|
||||
~FunctionType() = delete;
|
||||
static bool kindof(Kind kind) { return kind == Kind::Function; }
|
||||
};
|
||||
|
||||
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
|
||||
/// types, because many operations work on values of these aggregate types.
|
||||
class VectorOrTensorType : public Type {
|
||||
public:
|
||||
Type *getElementType() const { return elementType; }
|
||||
using ImplType = detail::VectorOrTensorTypeStorage;
|
||||
VectorOrTensorType() = default;
|
||||
/* implicit */ VectorOrTensorType(Type::ImplType *ptr);
|
||||
|
||||
Type getElementType() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the number of elements. If
|
||||
/// it is an unranked tensor or vector, abort.
|
||||
|
@ -319,56 +335,40 @@ public:
|
|||
int getDimSize(unsigned i) const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::Vector ||
|
||||
type->getKind() == Kind::RankedTensor ||
|
||||
type->getKind() == Kind::UnrankedTensor;
|
||||
static bool kindof(Kind kind) {
|
||||
return kind == Kind::Vector || kind == Kind::RankedTensor ||
|
||||
kind == Kind::UnrankedTensor;
|
||||
}
|
||||
|
||||
public:
|
||||
Type *elementType;
|
||||
|
||||
VectorOrTensorType(Kind kind, MLIRContext *context, Type *elementType,
|
||||
unsigned subClassData = 0);
|
||||
};
|
||||
|
||||
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
|
||||
/// known constant shape with one or more dimension.
|
||||
class VectorType : public VectorOrTensorType {
|
||||
public:
|
||||
static VectorType *get(ArrayRef<int> shape, Type *elementType);
|
||||
using ImplType = detail::VectorTypeStorage;
|
||||
VectorType() = default;
|
||||
/* implicit */ VectorType(Type::ImplType *ptr);
|
||||
|
||||
ArrayRef<int> getShape() const {
|
||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||
}
|
||||
static VectorType get(ArrayRef<int> shape, Type elementType);
|
||||
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::Vector;
|
||||
}
|
||||
|
||||
private:
|
||||
const int *shapeElements;
|
||||
Type *elementType;
|
||||
|
||||
VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context);
|
||||
~VectorType() = delete;
|
||||
static bool kindof(Kind kind) { return kind == Kind::Vector; }
|
||||
};
|
||||
|
||||
/// Tensor types represent multi-dimensional arrays, and have two variants:
|
||||
/// RankedTensorType and UnrankedTensorType.
|
||||
class TensorType : public VectorOrTensorType {
|
||||
public:
|
||||
using ImplType = detail::TensorTypeStorage;
|
||||
TensorType() = default;
|
||||
/* implicit */ TensorType(Type::ImplType *ptr);
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::RankedTensor ||
|
||||
type->getKind() == Kind::UnrankedTensor;
|
||||
static bool kindof(Kind kind) {
|
||||
return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor;
|
||||
}
|
||||
|
||||
protected:
|
||||
TensorType(Kind kind, Type *elementType, MLIRContext *context);
|
||||
~TensorType() {}
|
||||
};
|
||||
|
||||
/// Ranked tensor types represent multi-dimensional arrays that have a shape
|
||||
|
@ -376,40 +376,30 @@ protected:
|
|||
/// integer or unknown (represented -1).
|
||||
class RankedTensorType : public TensorType {
|
||||
public:
|
||||
static RankedTensorType *get(ArrayRef<int> shape,
|
||||
Type *elementType);
|
||||
using ImplType = detail::RankedTensorTypeStorage;
|
||||
RankedTensorType() = default;
|
||||
/* implicit */ RankedTensorType(Type::ImplType *ptr);
|
||||
|
||||
ArrayRef<int> getShape() const {
|
||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||
}
|
||||
static RankedTensorType get(ArrayRef<int> shape, Type elementType);
|
||||
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::RankedTensor;
|
||||
}
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
private:
|
||||
const int *shapeElements;
|
||||
|
||||
RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
||||
MLIRContext *context);
|
||||
~RankedTensorType() = delete;
|
||||
static bool kindof(Kind kind) { return kind == Kind::RankedTensor; }
|
||||
};
|
||||
|
||||
/// Unranked tensor types represent multi-dimensional arrays that have an
|
||||
/// unknown shape.
|
||||
class UnrankedTensorType : public TensorType {
|
||||
public:
|
||||
static UnrankedTensorType *get(Type *elementType);
|
||||
using ImplType = detail::UnrankedTensorTypeStorage;
|
||||
UnrankedTensorType() = default;
|
||||
/* implicit */ UnrankedTensorType(Type::ImplType *ptr);
|
||||
|
||||
static UnrankedTensorType get(Type elementType);
|
||||
|
||||
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
|
||||
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::UnrankedTensor;
|
||||
}
|
||||
|
||||
private:
|
||||
UnrankedTensorType(Type *elementType, MLIRContext *context);
|
||||
~UnrankedTensorType() = delete;
|
||||
static bool kindof(Kind kind) { return kind == Kind::UnrankedTensor; }
|
||||
};
|
||||
|
||||
/// MemRef types represent a region of memory that have a shape with a fixed
|
||||
|
@ -418,62 +408,96 @@ private:
|
|||
/// affine map composition, represented as an array AffineMap pointers.
|
||||
class MemRefType : public Type {
|
||||
public:
|
||||
using ImplType = detail::MemRefTypeStorage;
|
||||
MemRefType() = default;
|
||||
/* implicit */ MemRefType(Type::ImplType *ptr);
|
||||
|
||||
/// Get or create a new MemRefType based on shape, element type, affine
|
||||
/// map composition, and memory space.
|
||||
static MemRefType *get(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace);
|
||||
static MemRefType get(ArrayRef<int> shape, Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace);
|
||||
|
||||
unsigned getRank() const { return getShape().size(); }
|
||||
|
||||
/// Returns an array of memref shape dimension sizes.
|
||||
ArrayRef<int> getShape() const {
|
||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||
}
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
/// Return the size of the specified dimension, or -1 if unspecified.
|
||||
int getDimSize(unsigned i) const { return getShape()[i]; }
|
||||
|
||||
/// Returns the elemental type for this memref shape.
|
||||
Type *getElementType() const { return elementType; }
|
||||
Type getElementType() const;
|
||||
|
||||
/// Returns an array of affine map pointers representing the memref affine
|
||||
/// map composition.
|
||||
ArrayRef<AffineMap> getAffineMaps() const;
|
||||
|
||||
/// Returns the memory space in which data referred to by this memref resides.
|
||||
unsigned getMemorySpace() const { return memorySpace; }
|
||||
unsigned getMemorySpace() const;
|
||||
|
||||
/// Returns the number of dimensions with dynamic size.
|
||||
unsigned getNumDynamicDims() const;
|
||||
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::MemRef;
|
||||
}
|
||||
|
||||
private:
|
||||
/// The type of each scalar element of the memref.
|
||||
Type *elementType;
|
||||
/// An array of integers which stores the shape dimension sizes.
|
||||
const int *shapeElements;
|
||||
/// The number of affine maps in the 'affineMapList' array.
|
||||
const unsigned numAffineMaps;
|
||||
/// List of affine maps in the memref's layout/index map composition.
|
||||
AffineMap const *affineMapList;
|
||||
/// Memory space in which data referenced by memref resides.
|
||||
const unsigned memorySpace;
|
||||
|
||||
MemRefType(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
|
||||
MLIRContext *context);
|
||||
~MemRefType() = delete;
|
||||
static bool kindof(Kind kind) { return kind == Kind::MemRef; }
|
||||
};
|
||||
|
||||
/// Return true if the specified element type is ok in a tensor.
|
||||
static bool isValidTensorElementType(Type *type) {
|
||||
return isa<FloatType>(type) || isa<VectorType>(type) ||
|
||||
isa<IntegerType>(type) || isa<OtherType>(type);
|
||||
// Make Type hashable.
|
||||
inline ::llvm::hash_code hash_value(Type arg) {
|
||||
return ::llvm::hash_value(arg.type);
|
||||
}
|
||||
|
||||
template <typename U> bool Type::isa() const {
|
||||
assert(type && "isa<> used on a null type.");
|
||||
return U::kindof(getKind());
|
||||
}
|
||||
template <typename U> U Type::dyn_cast() const {
|
||||
return isa<U>() ? U(type) : U(nullptr);
|
||||
}
|
||||
template <typename U> U Type::dyn_cast_or_null() const {
|
||||
return (type && isa<U>()) ? U(type) : U(nullptr);
|
||||
}
|
||||
template <typename U> U Type::cast() const {
|
||||
assert(isa<U>());
|
||||
return U(type);
|
||||
}
|
||||
|
||||
/// Return true if the specified element type is ok in a tensor.
|
||||
static bool isValidTensorElementType(Type type) {
|
||||
return type.isa<FloatType>() || type.isa<VectorType>() ||
|
||||
type.isa<IntegerType>() || type.isa<OtherType>();
|
||||
}
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
|
||||
// Type hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::Type> {
|
||||
static mlir::Type getEmptyKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
|
||||
}
|
||||
static mlir::Type getTombstoneKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
|
||||
}
|
||||
static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
|
||||
static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
|
||||
};
|
||||
|
||||
/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
|
||||
template <> struct PointerLikeTypeTraits<mlir::Type> {
|
||||
public:
|
||||
static inline void *getAsVoidPointer(mlir::Type I) {
|
||||
return const_cast<void *>(I.getAsOpaquePointer());
|
||||
}
|
||||
static inline mlir::Type getFromVoidPointer(void *P) {
|
||||
return mlir::Type::getFromOpaquePointer(P);
|
||||
}
|
||||
enum { NumLowBitsAvailable = 3 };
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
#endif // MLIR_IR_TYPES_H
|
||||
|
|
|
@ -104,15 +104,15 @@ class AllocOp
|
|||
: public Op<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
|
||||
public:
|
||||
/// The result of an alloc is always a MemRefType.
|
||||
MemRefType *getType() const {
|
||||
return cast<MemRefType>(getResult()->getType());
|
||||
MemRefType getType() const {
|
||||
return getResult()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
static StringRef getOperationName() { return "alloc"; }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
MemRefType *memrefType, ArrayRef<SSAValue *> operands = {});
|
||||
MemRefType memrefType, ArrayRef<SSAValue *> operands = {});
|
||||
bool verify() const;
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
|
@ -276,7 +276,7 @@ public:
|
|||
const SSAValue *getSrcMemRef() const { return getOperand(0); }
|
||||
// Returns the rank (number of indices) of the source MemRefType.
|
||||
unsigned getSrcMemRefRank() const {
|
||||
return cast<MemRefType>(getSrcMemRef()->getType())->getRank();
|
||||
return getSrcMemRef()->getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
// Returns the source memerf indices for this DMA operation.
|
||||
llvm::iterator_range<Operation::const_operand_iterator>
|
||||
|
@ -291,13 +291,13 @@ public:
|
|||
}
|
||||
// Returns the rank (number of indices) of the destination MemRefType.
|
||||
unsigned getDstMemRefRank() const {
|
||||
return cast<MemRefType>(getDstMemRef()->getType())->getRank();
|
||||
return getDstMemRef()->getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
unsigned getSrcMemorySpace() const {
|
||||
return cast<MemRefType>(getSrcMemRef()->getType())->getMemorySpace();
|
||||
return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
|
||||
}
|
||||
unsigned getDstMemorySpace() const {
|
||||
return cast<MemRefType>(getDstMemRef()->getType())->getMemorySpace();
|
||||
return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
|
||||
}
|
||||
|
||||
// Returns the destination memref indices for this DMA operation.
|
||||
|
@ -387,7 +387,7 @@ public:
|
|||
|
||||
// Returns the rank (number of indices) of the tag memref.
|
||||
unsigned getTagMemRefRank() const {
|
||||
return cast<MemRefType>(getTagMemRef()->getType())->getRank();
|
||||
return getTagMemRef()->getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
// Returns the number of elements transferred in the associated DMA operation.
|
||||
|
@ -460,8 +460,8 @@ public:
|
|||
SSAValue *getMemRef() { return getOperand(0); }
|
||||
const SSAValue *getMemRef() const { return getOperand(0); }
|
||||
void setMemRef(SSAValue *value) { setOperand(0, value); }
|
||||
MemRefType *getMemRefType() const {
|
||||
return cast<MemRefType>(getMemRef()->getType());
|
||||
MemRefType getMemRefType() const {
|
||||
return getMemRef()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
llvm::iterator_range<Operation::operand_iterator> getIndices() {
|
||||
|
@ -508,8 +508,8 @@ public:
|
|||
static StringRef getOperationName() { return "memref_cast"; }
|
||||
|
||||
/// The result of a memref_cast is always a memref.
|
||||
MemRefType *getType() const {
|
||||
return cast<MemRefType>(getResult()->getType());
|
||||
MemRefType getType() const {
|
||||
return getResult()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
bool verify() const;
|
||||
|
@ -583,8 +583,8 @@ public:
|
|||
SSAValue *getMemRef() { return getOperand(1); }
|
||||
const SSAValue *getMemRef() const { return getOperand(1); }
|
||||
void setMemRef(SSAValue *value) { setOperand(1, value); }
|
||||
MemRefType *getMemRefType() const {
|
||||
return cast<MemRefType>(getMemRef()->getType());
|
||||
MemRefType getMemRefType() const {
|
||||
return getMemRef()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
llvm::iterator_range<Operation::operand_iterator> getIndices() {
|
||||
|
@ -671,8 +671,8 @@ public:
|
|||
static StringRef getOperationName() { return "tensor_cast"; }
|
||||
|
||||
/// The result of a tensor_cast is always a tensor.
|
||||
TensorType *getType() const {
|
||||
return cast<TensorType>(getResult()->getType());
|
||||
TensorType getType() const {
|
||||
return getResult()->getType().cast<TensorType>();
|
||||
}
|
||||
|
||||
bool verify() const;
|
||||
|
|
|
@ -118,15 +118,15 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
|
|||
return tripCountExpr.getLargestKnownDivisor();
|
||||
}
|
||||
|
||||
bool mlir::isAccessInvariant(const MLValue &input, MemRefType *memRefType,
|
||||
bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
|
||||
ArrayRef<MLValue *> indices, unsigned dim) {
|
||||
assert(indices.size() == memRefType->getRank());
|
||||
assert(indices.size() == memRefType.getRank());
|
||||
assert(dim < indices.size());
|
||||
auto layoutMap = memRefType->getAffineMaps();
|
||||
assert(memRefType->getAffineMaps().size() <= 1);
|
||||
auto layoutMap = memRefType.getAffineMaps();
|
||||
assert(memRefType.getAffineMaps().size() <= 1);
|
||||
// TODO(ntv): remove dependency on Builder once we support non-identity
|
||||
// layout map.
|
||||
Builder b(memRefType->getContext());
|
||||
Builder b(memRefType.getContext());
|
||||
assert(layoutMap.empty() ||
|
||||
layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
|
||||
(void)layoutMap;
|
||||
|
@ -170,7 +170,7 @@ static bool isContiguousAccess(const MLValue &input,
|
|||
using namespace functional;
|
||||
auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
|
||||
memoryOp->getIndices());
|
||||
auto *memRefType = memoryOp->getMemRefType();
|
||||
auto memRefType = memoryOp->getMemRefType();
|
||||
for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
|
||||
if (fastestVaryingDim == (numIndices - 1) - d) {
|
||||
continue;
|
||||
|
@ -184,8 +184,8 @@ static bool isContiguousAccess(const MLValue &input,
|
|||
|
||||
template <typename LoadOrStoreOpPointer>
|
||||
static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
|
||||
auto *memRefType = memoryOp->getMemRefType();
|
||||
return isa<VectorType>(memRefType->getElementType());
|
||||
auto memRefType = memoryOp->getMemRefType();
|
||||
return memRefType.getElementType().template isa<VectorType>();
|
||||
}
|
||||
|
||||
bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {
|
||||
|
|
|
@ -195,7 +195,7 @@ bool CFGFuncVerifier::verify() {
|
|||
|
||||
// Verify that the argument list of the function and the arg list of the first
|
||||
// block line up.
|
||||
auto fnInputTypes = fn.getType()->getInputs();
|
||||
auto fnInputTypes = fn.getType().getInputs();
|
||||
if (fnInputTypes.size() != firstBB->getNumArguments())
|
||||
return failure("first block of cfgfunc must have " +
|
||||
Twine(fnInputTypes.size()) +
|
||||
|
@ -306,7 +306,7 @@ bool CFGFuncVerifier::verifyBBArguments(ArrayRef<InstOperand> operands,
|
|||
|
||||
bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
|
||||
// Verify that the return operands match the results of the function.
|
||||
auto results = fn.getType()->getResults();
|
||||
auto results = fn.getType().getResults();
|
||||
if (inst.getNumOperands() != results.size())
|
||||
return failure("return has " + Twine(inst.getNumOperands()) +
|
||||
" operands, but enclosing function returns " +
|
||||
|
|
|
@ -122,7 +122,7 @@ private:
|
|||
void visitForStmt(const ForStmt *forStmt);
|
||||
void visitIfStmt(const IfStmt *ifStmt);
|
||||
void visitOperationStmt(const OperationStmt *opStmt);
|
||||
void visitType(const Type *type);
|
||||
void visitType(Type type);
|
||||
void visitAttribute(Attribute attr);
|
||||
void visitOperation(const Operation *op);
|
||||
|
||||
|
@ -135,16 +135,16 @@ private:
|
|||
} // end anonymous namespace
|
||||
|
||||
// TODO Support visiting other types/instructions when implemented.
|
||||
void ModuleState::visitType(const Type *type) {
|
||||
if (auto *funcType = dyn_cast<FunctionType>(type)) {
|
||||
void ModuleState::visitType(Type type) {
|
||||
if (auto funcType = type.dyn_cast<FunctionType>()) {
|
||||
// Visit input and result types for functions.
|
||||
for (auto *input : funcType->getInputs())
|
||||
for (auto input : funcType.getInputs())
|
||||
visitType(input);
|
||||
for (auto *result : funcType->getResults())
|
||||
for (auto result : funcType.getResults())
|
||||
visitType(result);
|
||||
} else if (auto *memref = dyn_cast<MemRefType>(type)) {
|
||||
} else if (auto memref = type.dyn_cast<MemRefType>()) {
|
||||
// Visit affine maps in memref type.
|
||||
for (auto map : memref->getAffineMaps()) {
|
||||
for (auto map : memref.getAffineMaps()) {
|
||||
recordAffineMapReference(map);
|
||||
}
|
||||
}
|
||||
|
@ -271,7 +271,7 @@ public:
|
|||
void print(const Module *module);
|
||||
void printFunctionReference(const Function *func);
|
||||
void printAttribute(Attribute attr);
|
||||
void printType(const Type *type);
|
||||
void printType(Type type);
|
||||
void print(const Function *fn);
|
||||
void print(const ExtFunction *fn);
|
||||
void print(const CFGFunction *fn);
|
||||
|
@ -290,7 +290,7 @@ protected:
|
|||
void printFunctionAttributes(const Function *fn);
|
||||
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
||||
ArrayRef<const char *> elidedAttrs = {});
|
||||
void printFunctionResultType(const FunctionType *type);
|
||||
void printFunctionResultType(FunctionType type);
|
||||
void printAffineMapId(int affineMapId) const;
|
||||
void printAffineMapReference(AffineMap affineMap);
|
||||
void printIntegerSetId(int integerSetId) const;
|
||||
|
@ -489,9 +489,9 @@ void ModulePrinter::printAttribute(Attribute attr) {
|
|||
}
|
||||
|
||||
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
||||
auto *type = attr.getType();
|
||||
auto shape = type->getShape();
|
||||
auto rank = type->getRank();
|
||||
auto type = attr.getType();
|
||||
auto shape = type.getShape();
|
||||
auto rank = type.getRank();
|
||||
|
||||
SmallVector<Attribute, 16> elements;
|
||||
attr.getValues(elements);
|
||||
|
@ -541,8 +541,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
|||
os << ']';
|
||||
}
|
||||
|
||||
void ModulePrinter::printType(const Type *type) {
|
||||
switch (type->getKind()) {
|
||||
void ModulePrinter::printType(Type type) {
|
||||
switch (type.getKind()) {
|
||||
case Type::Kind::Index:
|
||||
os << "index";
|
||||
return;
|
||||
|
@ -581,71 +581,71 @@ void ModulePrinter::printType(const Type *type) {
|
|||
return;
|
||||
|
||||
case Type::Kind::Integer: {
|
||||
auto *integer = cast<IntegerType>(type);
|
||||
os << 'i' << integer->getWidth();
|
||||
auto integer = type.cast<IntegerType>();
|
||||
os << 'i' << integer.getWidth();
|
||||
return;
|
||||
}
|
||||
case Type::Kind::Function: {
|
||||
auto *func = cast<FunctionType>(type);
|
||||
auto func = type.cast<FunctionType>();
|
||||
os << '(';
|
||||
interleaveComma(func->getInputs(), [&](Type *type) { printType(type); });
|
||||
interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
|
||||
os << ") -> ";
|
||||
auto results = func->getResults();
|
||||
auto results = func.getResults();
|
||||
if (results.size() == 1)
|
||||
os << *results[0];
|
||||
os << results[0];
|
||||
else {
|
||||
os << '(';
|
||||
interleaveComma(results, [&](Type *type) { printType(type); });
|
||||
interleaveComma(results, [&](Type type) { printType(type); });
|
||||
os << ')';
|
||||
}
|
||||
return;
|
||||
}
|
||||
case Type::Kind::Vector: {
|
||||
auto *v = cast<VectorType>(type);
|
||||
auto v = type.cast<VectorType>();
|
||||
os << "vector<";
|
||||
for (auto dim : v->getShape())
|
||||
for (auto dim : v.getShape())
|
||||
os << dim << 'x';
|
||||
os << *v->getElementType() << '>';
|
||||
os << v.getElementType() << '>';
|
||||
return;
|
||||
}
|
||||
case Type::Kind::RankedTensor: {
|
||||
auto *v = cast<RankedTensorType>(type);
|
||||
auto v = type.cast<RankedTensorType>();
|
||||
os << "tensor<";
|
||||
for (auto dim : v->getShape()) {
|
||||
for (auto dim : v.getShape()) {
|
||||
if (dim < 0)
|
||||
os << '?';
|
||||
else
|
||||
os << dim;
|
||||
os << 'x';
|
||||
}
|
||||
os << *v->getElementType() << '>';
|
||||
os << v.getElementType() << '>';
|
||||
return;
|
||||
}
|
||||
case Type::Kind::UnrankedTensor: {
|
||||
auto *v = cast<UnrankedTensorType>(type);
|
||||
auto v = type.cast<UnrankedTensorType>();
|
||||
os << "tensor<*x";
|
||||
printType(v->getElementType());
|
||||
printType(v.getElementType());
|
||||
os << '>';
|
||||
return;
|
||||
}
|
||||
case Type::Kind::MemRef: {
|
||||
auto *v = cast<MemRefType>(type);
|
||||
auto v = type.cast<MemRefType>();
|
||||
os << "memref<";
|
||||
for (auto dim : v->getShape()) {
|
||||
for (auto dim : v.getShape()) {
|
||||
if (dim < 0)
|
||||
os << '?';
|
||||
else
|
||||
os << dim;
|
||||
os << 'x';
|
||||
}
|
||||
printType(v->getElementType());
|
||||
for (auto map : v->getAffineMaps()) {
|
||||
printType(v.getElementType());
|
||||
for (auto map : v.getAffineMaps()) {
|
||||
os << ", ";
|
||||
printAffineMapReference(map);
|
||||
}
|
||||
// Only print the memory space if it is the non-default one.
|
||||
if (v->getMemorySpace())
|
||||
os << ", " << v->getMemorySpace();
|
||||
if (v.getMemorySpace())
|
||||
os << ", " << v.getMemorySpace();
|
||||
os << '>';
|
||||
return;
|
||||
}
|
||||
|
@ -842,18 +842,18 @@ void ModulePrinter::printIntegerSet(IntegerSet set) {
|
|||
// Function printing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ModulePrinter::printFunctionResultType(const FunctionType *type) {
|
||||
switch (type->getResults().size()) {
|
||||
void ModulePrinter::printFunctionResultType(FunctionType type) {
|
||||
switch (type.getResults().size()) {
|
||||
case 0:
|
||||
break;
|
||||
case 1:
|
||||
os << " -> ";
|
||||
printType(type->getResults()[0]);
|
||||
printType(type.getResults()[0]);
|
||||
break;
|
||||
default:
|
||||
os << " -> (";
|
||||
interleaveComma(type->getResults(),
|
||||
[&](Type *eltType) { printType(eltType); });
|
||||
interleaveComma(type.getResults(),
|
||||
[&](Type eltType) { printType(eltType); });
|
||||
os << ')';
|
||||
break;
|
||||
}
|
||||
|
@ -871,8 +871,7 @@ void ModulePrinter::printFunctionSignature(const Function *fn) {
|
|||
auto type = fn->getType();
|
||||
|
||||
os << "@" << fn->getName() << '(';
|
||||
interleaveComma(type->getInputs(),
|
||||
[&](Type *eltType) { printType(eltType); });
|
||||
interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
|
||||
os << ')';
|
||||
|
||||
printFunctionResultType(type);
|
||||
|
@ -937,7 +936,7 @@ public:
|
|||
|
||||
// Implement OpAsmPrinter.
|
||||
raw_ostream &getStream() const { return os; }
|
||||
void printType(const Type *type) { ModulePrinter::printType(type); }
|
||||
void printType(Type type) { ModulePrinter::printType(type); }
|
||||
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
|
||||
void printAffineMap(AffineMap map) {
|
||||
return ModulePrinter::printAffineMapReference(map);
|
||||
|
@ -974,10 +973,10 @@ protected:
|
|||
if (auto *op = value->getDefiningOperation()) {
|
||||
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
|
||||
// i1 constants get special names.
|
||||
if (intOp->getType()->isInteger(1)) {
|
||||
if (intOp->getType().isInteger(1)) {
|
||||
specialName << (intOp->getValue() ? "true" : "false");
|
||||
} else {
|
||||
specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
|
||||
specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
|
||||
}
|
||||
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
|
||||
specialName << 'c' << intOp->getValue();
|
||||
|
@ -1579,7 +1578,7 @@ void Attribute::dump() const { print(llvm::errs()); }
|
|||
|
||||
void Type::print(raw_ostream &os) const {
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).printType(this);
|
||||
ModulePrinter(os, state).printType(*this);
|
||||
}
|
||||
|
||||
void Type::dump() const { print(llvm::errs()); }
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "llvm/Support/TrailingObjects.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -86,7 +87,7 @@ struct IntegerSetAttributeStorage : public AttributeStorage {
|
|||
|
||||
/// An attribute representing a reference to a type.
|
||||
struct TypeAttributeStorage : public AttributeStorage {
|
||||
Type *value;
|
||||
Type value;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a function.
|
||||
|
@ -96,7 +97,7 @@ struct FunctionAttributeStorage : public AttributeStorage {
|
|||
|
||||
/// A base attribute representing a reference to a vector or tensor constant.
|
||||
struct ElementsAttributeStorage : public AttributeStorage {
|
||||
VectorOrTensorType *type;
|
||||
VectorOrTensorType type;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a vector or tensor constant,
|
||||
|
|
|
@ -75,9 +75,7 @@ IntegerSet IntegerSetAttr::getValue() const {
|
|||
|
||||
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
Type *TypeAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
|
||||
|
||||
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
|
@ -85,11 +83,11 @@ Function *FunctionAttr::getValue() const {
|
|||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
|
||||
FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
|
||||
|
||||
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
VectorOrTensorType *ElementsAttr::getType() const {
|
||||
VectorOrTensorType ElementsAttr::getType() const {
|
||||
return static_cast<ImplType *>(attr)->type;
|
||||
}
|
||||
|
||||
|
@ -166,8 +164,8 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
|
|||
|
||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
|
||||
auto elementNum = getType()->getNumElements();
|
||||
auto context = getType()->getContext();
|
||||
auto elementNum = getType().getNumElements();
|
||||
auto context = getType().getContext();
|
||||
values.reserve(elementNum);
|
||||
if (bitsWidth == 64) {
|
||||
ArrayRef<int64_t> vs(
|
||||
|
@ -192,8 +190,8 @@ DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
|
|||
: DenseElementsAttr(ptr) {}
|
||||
|
||||
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto elementNum = getType()->getNumElements();
|
||||
auto context = getType()->getContext();
|
||||
auto elementNum = getType().getNumElements();
|
||||
auto context = getType().getContext();
|
||||
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
|
||||
getRawData().size() / 8});
|
||||
values.reserve(elementNum);
|
||||
|
|
|
@ -33,18 +33,18 @@ BasicBlock::~BasicBlock() {
|
|||
// Argument list management.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
BBArgument *BasicBlock::addArgument(Type *type) {
|
||||
BBArgument *BasicBlock::addArgument(Type type) {
|
||||
auto *arg = new BBArgument(type, this);
|
||||
arguments.push_back(arg);
|
||||
return arg;
|
||||
}
|
||||
|
||||
/// Add one argument to the argument list for each type specified in the list.
|
||||
auto BasicBlock::addArguments(ArrayRef<Type *> types)
|
||||
auto BasicBlock::addArguments(ArrayRef<Type> types)
|
||||
-> llvm::iterator_range<args_iterator> {
|
||||
arguments.reserve(arguments.size() + types.size());
|
||||
auto initialSize = arguments.size();
|
||||
for (auto *type : types) {
|
||||
for (auto type : types) {
|
||||
addArgument(type);
|
||||
}
|
||||
return {arguments.data() + initialSize, arguments.data() + arguments.size()};
|
||||
|
|
|
@ -52,59 +52,58 @@ FileLineColLoc *Builder::getFileLineColLoc(UniquedFilename filename,
|
|||
// Types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FloatType *Builder::getBF16Type() { return Type::getBF16(context); }
|
||||
FloatType Builder::getBF16Type() { return Type::getBF16(context); }
|
||||
|
||||
FloatType *Builder::getF16Type() { return Type::getF16(context); }
|
||||
FloatType Builder::getF16Type() { return Type::getF16(context); }
|
||||
|
||||
FloatType *Builder::getF32Type() { return Type::getF32(context); }
|
||||
FloatType Builder::getF32Type() { return Type::getF32(context); }
|
||||
|
||||
FloatType *Builder::getF64Type() { return Type::getF64(context); }
|
||||
FloatType Builder::getF64Type() { return Type::getF64(context); }
|
||||
|
||||
OtherType *Builder::getIndexType() { return Type::getIndex(context); }
|
||||
OtherType Builder::getIndexType() { return Type::getIndex(context); }
|
||||
|
||||
OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
|
||||
OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
|
||||
|
||||
OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
|
||||
OtherType Builder::getTFResourceType() { return Type::getTFResource(context); }
|
||||
|
||||
OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
|
||||
OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); }
|
||||
|
||||
OtherType *Builder::getTFComplex64Type() {
|
||||
OtherType Builder::getTFComplex64Type() {
|
||||
return Type::getTFComplex64(context);
|
||||
}
|
||||
|
||||
OtherType *Builder::getTFComplex128Type() {
|
||||
OtherType Builder::getTFComplex128Type() {
|
||||
return Type::getTFComplex128(context);
|
||||
}
|
||||
|
||||
OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
|
||||
OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
|
||||
|
||||
OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
|
||||
OtherType Builder::getTFStringType() { return Type::getTFString(context); }
|
||||
|
||||
IntegerType *Builder::getIntegerType(unsigned width) {
|
||||
IntegerType Builder::getIntegerType(unsigned width) {
|
||||
return Type::getInteger(width, context);
|
||||
}
|
||||
|
||||
FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
|
||||
ArrayRef<Type *> results) {
|
||||
FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
|
||||
ArrayRef<Type> results) {
|
||||
return FunctionType::get(inputs, results, context);
|
||||
}
|
||||
|
||||
MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace) {
|
||||
MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace) {
|
||||
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
|
||||
}
|
||||
|
||||
VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
|
||||
VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) {
|
||||
return VectorType::get(shape, elementType);
|
||||
}
|
||||
|
||||
RankedTensorType *Builder::getTensorType(ArrayRef<int> shape,
|
||||
Type *elementType) {
|
||||
RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) {
|
||||
return RankedTensorType::get(shape, elementType);
|
||||
}
|
||||
|
||||
UnrankedTensorType *Builder::getTensorType(Type *elementType) {
|
||||
UnrankedTensorType Builder::getTensorType(Type elementType) {
|
||||
return UnrankedTensorType::get(elementType);
|
||||
}
|
||||
|
||||
|
@ -144,7 +143,7 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
|
|||
return IntegerSetAttr::get(set);
|
||||
}
|
||||
|
||||
TypeAttr Builder::getTypeAttr(Type *type) {
|
||||
TypeAttr Builder::getTypeAttr(Type type) {
|
||||
return TypeAttr::get(type, context);
|
||||
}
|
||||
|
||||
|
@ -152,23 +151,23 @@ FunctionAttr Builder::getFunctionAttr(const Function *value) {
|
|||
return FunctionAttr::get(value, context);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
|
||||
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,
|
||||
Attribute elt) {
|
||||
return SplatElementsAttr::get(type, elt);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
|
||||
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
|
||||
ArrayRef<char> data) {
|
||||
return DenseElementsAttr::get(type, data);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
|
||||
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values) {
|
||||
return SparseElementsAttr::get(type, indices, values);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
|
||||
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type,
|
||||
StringRef bytes) {
|
||||
return OpaqueElementsAttr::get(type, bytes);
|
||||
}
|
||||
|
@ -296,7 +295,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
|
|||
OperationStmt *MLFuncBuilder::createOperation(Location *location,
|
||||
OperationName name,
|
||||
ArrayRef<MLValue *> operands,
|
||||
ArrayRef<Type *> types,
|
||||
ArrayRef<Type> types,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
auto *op = OperationStmt::create(location, name, operands, types, attrs,
|
||||
getContext());
|
||||
|
|
|
@ -63,7 +63,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
|
|||
numDims = opInfos.size();
|
||||
|
||||
// Parse the optional symbol operands.
|
||||
auto *affineIntTy = parser->getBuilder().getIndexType();
|
||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||
if (parser->parseOperandList(opInfos, -1,
|
||||
OpAsmParser::Delimiter::OptionalSquare) ||
|
||||
parser->resolveOperands(opInfos, affineIntTy, operands))
|
||||
|
@ -84,7 +84,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result,
|
|||
|
||||
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
auto &builder = parser->getBuilder();
|
||||
auto *affineIntTy = builder.getIndexType();
|
||||
auto affineIntTy = builder.getIndexType();
|
||||
|
||||
AffineMapAttr mapAttr;
|
||||
unsigned numDims;
|
||||
|
@ -171,7 +171,7 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
|
|||
|
||||
/// Builds a constant op with the specified attribute value and result type.
|
||||
void ConstantOp::build(Builder *builder, OperationState *result,
|
||||
Attribute value, Type *type) {
|
||||
Attribute value, Type type) {
|
||||
result->addAttribute("value", value);
|
||||
result->types.push_back(type);
|
||||
}
|
||||
|
@ -181,12 +181,12 @@ void ConstantOp::print(OpAsmPrinter *p) const {
|
|||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
|
||||
|
||||
if (!getValue().isa<FunctionAttr>())
|
||||
*p << " : " << *getType();
|
||||
*p << " : " << getType();
|
||||
}
|
||||
|
||||
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
Attribute valueAttr;
|
||||
Type *type;
|
||||
Type type;
|
||||
|
||||
if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes))
|
||||
|
@ -208,33 +208,33 @@ bool ConstantOp::verify() const {
|
|||
if (!value)
|
||||
return emitOpError("requires a 'value' attribute");
|
||||
|
||||
auto *type = this->getType();
|
||||
if (isa<IntegerType>(type) || type->isIndex()) {
|
||||
auto type = this->getType();
|
||||
if (type.isa<IntegerType>() || type.isIndex()) {
|
||||
if (!value.isa<IntegerAttr>())
|
||||
return emitOpError(
|
||||
"requires 'value' to be an integer for an integer result type");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isa<FloatType>(type)) {
|
||||
if (type.isa<FloatType>()) {
|
||||
if (!value.isa<FloatAttr>())
|
||||
return emitOpError("requires 'value' to be a floating point constant");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isa<VectorOrTensorType>(type)) {
|
||||
if (type.isa<VectorOrTensorType>()) {
|
||||
if (!value.isa<ElementsAttr>())
|
||||
return emitOpError("requires 'value' to be a vector/tensor constant");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (type->isTFString()) {
|
||||
if (type.isTFString()) {
|
||||
if (!value.isa<StringAttr>())
|
||||
return emitOpError("requires 'value' to be a string constant");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isa<FunctionType>(type)) {
|
||||
if (type.isa<FunctionType>()) {
|
||||
if (!value.isa<FunctionAttr>())
|
||||
return emitOpError("requires 'value' to be a function reference");
|
||||
return false;
|
||||
|
@ -251,19 +251,19 @@ Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
|
|||
}
|
||||
|
||||
void ConstantFloatOp::build(Builder *builder, OperationState *result,
|
||||
const APFloat &value, FloatType *type) {
|
||||
const APFloat &value, FloatType type) {
|
||||
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
|
||||
}
|
||||
|
||||
bool ConstantFloatOp::isClassFor(const Operation *op) {
|
||||
return ConstantOp::isClassFor(op) &&
|
||||
isa<FloatType>(op->getResult(0)->getType());
|
||||
op->getResult(0)->getType().isa<FloatType>();
|
||||
}
|
||||
|
||||
/// ConstantIntOp only matches values whose result type is an IntegerType.
|
||||
bool ConstantIntOp::isClassFor(const Operation *op) {
|
||||
return ConstantOp::isClassFor(op) &&
|
||||
isa<IntegerType>(op->getResult(0)->getType());
|
||||
op->getResult(0)->getType().isa<IntegerType>();
|
||||
}
|
||||
|
||||
void ConstantIntOp::build(Builder *builder, OperationState *result,
|
||||
|
@ -275,14 +275,14 @@ void ConstantIntOp::build(Builder *builder, OperationState *result,
|
|||
/// Build a constant int op producing an integer with the specified type,
|
||||
/// which must be an integer type.
|
||||
void ConstantIntOp::build(Builder *builder, OperationState *result,
|
||||
int64_t value, Type *type) {
|
||||
assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type");
|
||||
int64_t value, Type type) {
|
||||
assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
|
||||
ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
|
||||
}
|
||||
|
||||
/// ConstantIndexOp only matches values whose result type is Index.
|
||||
bool ConstantIndexOp::isClassFor(const Operation *op) {
|
||||
return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
|
||||
return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
|
||||
}
|
||||
|
||||
void ConstantIndexOp::build(Builder *builder, OperationState *result,
|
||||
|
@ -302,7 +302,7 @@ void ReturnOp::build(Builder *builder, OperationState *result,
|
|||
|
||||
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> opInfo;
|
||||
SmallVector<Type *, 2> types;
|
||||
SmallVector<Type, 2> types;
|
||||
llvm::SMLoc loc;
|
||||
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
|
||||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
|
||||
|
@ -330,7 +330,7 @@ bool ReturnOp::verify() const {
|
|||
|
||||
// The operand number and types must match the function signature.
|
||||
MLFunction *function = cast<MLFunction>(block);
|
||||
const auto &results = function->getType()->getResults();
|
||||
const auto &results = function->getType().getResults();
|
||||
if (stmt->getNumOperands() != results.size())
|
||||
return emitOpError("has " + Twine(stmt->getNumOperands()) +
|
||||
" operands, but enclosing function returns " +
|
||||
|
|
|
@ -28,8 +28,8 @@
|
|||
using namespace mlir;
|
||||
|
||||
Function::Function(Kind kind, Location *location, StringRef name,
|
||||
FunctionType *type, ArrayRef<NamedAttribute> attrs)
|
||||
: nameAndKind(Identifier::get(name, type->getContext()), kind),
|
||||
FunctionType type, ArrayRef<NamedAttribute> attrs)
|
||||
: nameAndKind(Identifier::get(name, type.getContext()), kind),
|
||||
location(location), type(type) {
|
||||
this->attrs = AttributeListStorage::get(attrs, getContext());
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ ArrayRef<NamedAttribute> Function::getAttrs() const {
|
|||
return {};
|
||||
}
|
||||
|
||||
MLIRContext *Function::getContext() const { return getType()->getContext(); }
|
||||
MLIRContext *Function::getContext() const { return getType().getContext(); }
|
||||
|
||||
/// Delete this object.
|
||||
void Function::destroy() {
|
||||
|
@ -159,7 +159,7 @@ void Function::emitError(const Twine &message) const {
|
|||
// ExtFunction implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
|
||||
ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs)
|
||||
: Function(Kind::ExtFunc, location, name, type, attrs) {}
|
||||
|
||||
|
@ -167,7 +167,7 @@ ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
|
|||
// CFGFunction implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type,
|
||||
CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs)
|
||||
: Function(Kind::CFGFunc, location, name, type, attrs) {}
|
||||
|
||||
|
@ -188,9 +188,9 @@ CFGFunction::~CFGFunction() {
|
|||
|
||||
/// Create a new MLFunction with the specific fields.
|
||||
MLFunction *MLFunction::create(Location *location, StringRef name,
|
||||
FunctionType *type,
|
||||
FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
const auto &argTypes = type->getInputs();
|
||||
const auto &argTypes = type.getInputs();
|
||||
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
|
||||
void *rawMem = malloc(byteSize);
|
||||
|
||||
|
@ -204,7 +204,7 @@ MLFunction *MLFunction::create(Location *location, StringRef name,
|
|||
return function;
|
||||
}
|
||||
|
||||
MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type,
|
||||
MLFunction::MLFunction(Location *location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs)
|
||||
: Function(Kind::MLFunc, location, name, type, attrs),
|
||||
StmtBlock(StmtBlockKind::MLFunc) {}
|
||||
|
|
|
@ -143,7 +143,7 @@ void Instruction::emitError(const Twine &message) const {
|
|||
/// Create a new OperationInst with the specified fields.
|
||||
OperationInst *OperationInst::create(Location *location, OperationName name,
|
||||
ArrayRef<CFGValue *> operands,
|
||||
ArrayRef<Type *> resultTypes,
|
||||
ArrayRef<Type> resultTypes,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
MLIRContext *context) {
|
||||
auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(),
|
||||
|
@ -167,7 +167,7 @@ OperationInst *OperationInst::create(Location *location, OperationName name,
|
|||
|
||||
OperationInst *OperationInst::clone() const {
|
||||
SmallVector<CFGValue *, 8> operands;
|
||||
SmallVector<Type *, 8> resultTypes;
|
||||
SmallVector<Type, 8> resultTypes;
|
||||
|
||||
// Put together the operands and results.
|
||||
for (auto *operand : getOperands())
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "AttributeDetail.h"
|
||||
#include "AttributeListStorage.h"
|
||||
#include "IntegerSetDetail.h"
|
||||
#include "TypeDetail.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
@ -44,11 +45,11 @@ using namespace mlir::detail;
|
|||
using namespace llvm;
|
||||
|
||||
namespace {
|
||||
struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
|
||||
struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> {
|
||||
// Functions are uniqued based on their inputs and results.
|
||||
using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>;
|
||||
using DenseMapInfo<FunctionType *>::getHashValue;
|
||||
using DenseMapInfo<FunctionType *>::isEqual;
|
||||
using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>;
|
||||
using DenseMapInfo<FunctionTypeStorage *>::getHashValue;
|
||||
using DenseMapInfo<FunctionTypeStorage *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(
|
||||
|
@ -56,7 +57,7 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
|
|||
hash_combine_range(key.second.begin(), key.second.end()));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
|
||||
static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
|
||||
|
@ -109,65 +110,64 @@ struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
|
|||
}
|
||||
};
|
||||
|
||||
struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
|
||||
struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> {
|
||||
// Vectors are uniqued based on their element type and shape.
|
||||
using KeyTy = std::pair<Type *, ArrayRef<int>>;
|
||||
using DenseMapInfo<VectorType *>::getHashValue;
|
||||
using DenseMapInfo<VectorType *>::isEqual;
|
||||
using KeyTy = std::pair<Type, ArrayRef<int>>;
|
||||
using DenseMapInfo<VectorTypeStorage *>::getHashValue;
|
||||
using DenseMapInfo<VectorTypeStorage *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(
|
||||
DenseMapInfo<Type *>::getHashValue(key.first),
|
||||
DenseMapInfo<Type>::getHashValue(key.first),
|
||||
hash_combine_range(key.second.begin(), key.second.end()));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
|
||||
static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
|
||||
return lhs == KeyTy(rhs->elementType, rhs->getShape());
|
||||
}
|
||||
};
|
||||
|
||||
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> {
|
||||
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> {
|
||||
// Ranked tensors are uniqued based on their element type and shape.
|
||||
using KeyTy = std::pair<Type *, ArrayRef<int>>;
|
||||
using DenseMapInfo<RankedTensorType *>::getHashValue;
|
||||
using DenseMapInfo<RankedTensorType *>::isEqual;
|
||||
using KeyTy = std::pair<Type, ArrayRef<int>>;
|
||||
using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue;
|
||||
using DenseMapInfo<RankedTensorTypeStorage *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(
|
||||
DenseMapInfo<Type *>::getHashValue(key.first),
|
||||
DenseMapInfo<Type>::getHashValue(key.first),
|
||||
hash_combine_range(key.second.begin(), key.second.end()));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
|
||||
static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
|
||||
return lhs == KeyTy(rhs->elementType, rhs->getShape());
|
||||
}
|
||||
};
|
||||
|
||||
struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
|
||||
struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> {
|
||||
// MemRefs are uniqued based on their element type, shape, affine map
|
||||
// composition, and memory space.
|
||||
using KeyTy =
|
||||
std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
|
||||
using DenseMapInfo<MemRefType *>::getHashValue;
|
||||
using DenseMapInfo<MemRefType *>::isEqual;
|
||||
using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
|
||||
using DenseMapInfo<MemRefTypeStorage *>::getHashValue;
|
||||
using DenseMapInfo<MemRefTypeStorage *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(
|
||||
DenseMapInfo<Type *>::getHashValue(std::get<0>(key)),
|
||||
DenseMapInfo<Type>::getHashValue(std::get<0>(key)),
|
||||
hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()),
|
||||
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
|
||||
std::get<3>(key));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) {
|
||||
static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(),
|
||||
rhs->getAffineMaps(), rhs->getMemorySpace());
|
||||
return lhs == std::make_tuple(rhs->elementType, rhs->getShape(),
|
||||
rhs->getAffineMaps(), rhs->memorySpace);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -221,7 +221,7 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
|
|||
};
|
||||
|
||||
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
|
||||
using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
|
||||
using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
|
||||
using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
|
||||
using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
|
||||
|
||||
|
@ -239,7 +239,7 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
|
|||
};
|
||||
|
||||
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
|
||||
using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
|
||||
using KeyTy = std::pair<VectorOrTensorType, StringRef>;
|
||||
using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
|
||||
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
|
||||
|
||||
|
@ -295,13 +295,14 @@ public:
|
|||
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
|
||||
|
||||
// Uniquing table for 'other' types.
|
||||
OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
|
||||
int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr};
|
||||
OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
|
||||
int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {
|
||||
nullptr};
|
||||
|
||||
// Uniquing table for 'float' types.
|
||||
FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
|
||||
int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = {
|
||||
nullptr};
|
||||
FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
|
||||
int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] =
|
||||
{nullptr};
|
||||
|
||||
// Affine map uniquing.
|
||||
using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
|
||||
|
@ -324,26 +325,26 @@ public:
|
|||
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
|
||||
|
||||
/// Integer type uniquing.
|
||||
DenseMap<unsigned, IntegerType *> integers;
|
||||
DenseMap<unsigned, IntegerTypeStorage *> integers;
|
||||
|
||||
/// Function type uniquing.
|
||||
using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>;
|
||||
using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>;
|
||||
FunctionTypeSet functions;
|
||||
|
||||
/// Vector type uniquing.
|
||||
using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>;
|
||||
using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>;
|
||||
VectorTypeSet vectors;
|
||||
|
||||
/// Ranked tensor type uniquing.
|
||||
using RankedTensorTypeSet =
|
||||
DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>;
|
||||
DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>;
|
||||
RankedTensorTypeSet rankedTensors;
|
||||
|
||||
/// Unranked tensor type uniquing.
|
||||
DenseMap<Type *, UnrankedTensorType *> unrankedTensors;
|
||||
DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors;
|
||||
|
||||
/// MemRef type uniquing.
|
||||
using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>;
|
||||
using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>;
|
||||
MemRefTypeSet memrefs;
|
||||
|
||||
// Attribute uniquing.
|
||||
|
@ -355,13 +356,12 @@ public:
|
|||
ArrayAttrSet arrayAttrs;
|
||||
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
|
||||
DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
|
||||
DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
|
||||
DenseMap<Type, TypeAttributeStorage *> typeAttrs;
|
||||
using AttributeListSet =
|
||||
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
|
||||
AttributeListSet attributeLists;
|
||||
DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
|
||||
DenseMap<std::pair<VectorOrTensorType *, Attribute>,
|
||||
SplatElementsAttributeStorage *>
|
||||
DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *>
|
||||
splatElementsAttrs;
|
||||
using DenseElementsAttrSet =
|
||||
DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
|
||||
|
@ -369,7 +369,7 @@ public:
|
|||
using OpaqueElementsAttrSet =
|
||||
DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
|
||||
OpaqueElementsAttrSet opaqueElementsAttrs;
|
||||
DenseMap<std::tuple<Type *, Attribute, Attribute>,
|
||||
DenseMap<std::tuple<Type, Attribute, Attribute>,
|
||||
SparseElementsAttributeStorage *>
|
||||
sparseElementsAttrs;
|
||||
|
||||
|
@ -556,19 +556,20 @@ FileLineColLoc *FileLineColLoc::get(UniquedFilename filename, unsigned line,
|
|||
// Type uniquing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
|
||||
IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
|
||||
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
auto *&result = impl.integers[width];
|
||||
if (!result) {
|
||||
result = impl.allocator.Allocate<IntegerType>();
|
||||
new (result) IntegerType(width, context);
|
||||
result = impl.allocator.Allocate<IntegerTypeStorage>();
|
||||
new (result) IntegerTypeStorage{{Kind::Integer, context}, width};
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
FloatType *FloatType::get(Kind kind, MLIRContext *context) {
|
||||
FloatType FloatType::get(Kind kind, MLIRContext *context) {
|
||||
assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
|
||||
kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
|
||||
auto &impl = context->getImpl();
|
||||
|
@ -580,16 +581,16 @@ FloatType *FloatType::get(Kind kind, MLIRContext *context) {
|
|||
return entry;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
auto *ptr = impl.allocator.Allocate<FloatType>();
|
||||
auto *ptr = impl.allocator.Allocate<FloatTypeStorage>();
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (ptr) FloatType(kind, context);
|
||||
new (ptr) FloatTypeStorage{{kind, context}};
|
||||
|
||||
// Cache and return it.
|
||||
return entry = ptr;
|
||||
}
|
||||
|
||||
OtherType *OtherType::get(Kind kind, MLIRContext *context) {
|
||||
OtherType OtherType::get(Kind kind, MLIRContext *context) {
|
||||
assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE &&
|
||||
"Not an 'other' type kind");
|
||||
auto &impl = context->getImpl();
|
||||
|
@ -600,18 +601,17 @@ OtherType *OtherType::get(Kind kind, MLIRContext *context) {
|
|||
return entry;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
auto *ptr = impl.allocator.Allocate<OtherType>();
|
||||
auto *ptr = impl.allocator.Allocate<OtherTypeStorage>();
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (ptr) OtherType(kind, context);
|
||||
new (ptr) OtherTypeStorage{{kind, context}};
|
||||
|
||||
// Cache and return it.
|
||||
return entry = ptr;
|
||||
}
|
||||
|
||||
FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
|
||||
ArrayRef<Type *> results,
|
||||
MLIRContext *context) {
|
||||
FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
|
||||
MLIRContext *context) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Look to see if we already have this function type.
|
||||
|
@ -623,32 +623,34 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
|
|||
return *existing.first;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
auto *result = impl.allocator.Allocate<FunctionType>();
|
||||
auto *result = impl.allocator.Allocate<FunctionTypeStorage>();
|
||||
|
||||
// Copy the inputs and results into the bump pointer.
|
||||
SmallVector<Type *, 16> types;
|
||||
SmallVector<Type, 16> types;
|
||||
types.reserve(inputs.size() + results.size());
|
||||
types.append(inputs.begin(), inputs.end());
|
||||
types.append(results.begin(), results.end());
|
||||
auto typesList = impl.copyInto(ArrayRef<Type *>(types));
|
||||
auto typesList = impl.copyInto(ArrayRef<Type>(types));
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (result)
|
||||
FunctionType(typesList.data(), inputs.size(), results.size(), context);
|
||||
new (result) FunctionTypeStorage{
|
||||
{Kind::Function, context, static_cast<unsigned int>(inputs.size())},
|
||||
static_cast<unsigned int>(results.size()),
|
||||
typesList.data()};
|
||||
|
||||
// Cache and return it.
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
|
||||
VectorType VectorType::get(ArrayRef<int> shape, Type elementType) {
|
||||
assert(!shape.empty() && "vector types must have at least one dimension");
|
||||
assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
|
||||
assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) &&
|
||||
"vectors elements must be primitives");
|
||||
assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
|
||||
return i < 0;
|
||||
}) && "vector types must have static shape");
|
||||
|
||||
auto *context = elementType->getContext();
|
||||
auto *context = elementType.getContext();
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Look to see if we already have this vector type.
|
||||
|
@ -660,21 +662,23 @@ VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
|
|||
return *existing.first;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
auto *result = impl.allocator.Allocate<VectorType>();
|
||||
auto *result = impl.allocator.Allocate<VectorTypeStorage>();
|
||||
|
||||
// Copy the shape into the bump pointer.
|
||||
shape = impl.copyInto(shape);
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (result) VectorType(shape, elementType, context);
|
||||
new (result) VectorTypeStorage{
|
||||
{{Kind::Vector, context, static_cast<unsigned int>(shape.size())},
|
||||
elementType},
|
||||
shape.data()};
|
||||
|
||||
// Cache and return it.
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
|
||||
Type *elementType) {
|
||||
auto *context = elementType->getContext();
|
||||
RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) {
|
||||
auto *context = elementType.getContext();
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Look to see if we already have this ranked tensor type.
|
||||
|
@ -686,20 +690,23 @@ RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
|
|||
return *existing.first;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
auto *result = impl.allocator.Allocate<RankedTensorType>();
|
||||
auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>();
|
||||
|
||||
// Copy the shape into the bump pointer.
|
||||
shape = impl.copyInto(shape);
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (result) RankedTensorType(shape, elementType, context);
|
||||
new (result) RankedTensorTypeStorage{
|
||||
{{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())},
|
||||
elementType}},
|
||||
shape.data()};
|
||||
|
||||
// Cache and return it.
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
|
||||
auto *context = elementType->getContext();
|
||||
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
|
||||
auto *context = elementType.getContext();
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Look to see if we already have this unranked tensor type.
|
||||
|
@ -710,17 +717,18 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
|
|||
return result;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
result = impl.allocator.Allocate<UnrankedTensorType>();
|
||||
result = impl.allocator.Allocate<UnrankedTensorTypeStorage>();
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (result) UnrankedTensorType(elementType, context);
|
||||
new (result) UnrankedTensorTypeStorage{
|
||||
{{{Kind::UnrankedTensor, context}, elementType}}};
|
||||
return result;
|
||||
}
|
||||
|
||||
MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace) {
|
||||
auto *context = elementType->getContext();
|
||||
MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace) {
|
||||
auto *context = elementType.getContext();
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Drop the unbounded identity maps from the composition.
|
||||
|
@ -744,7 +752,7 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
|
|||
return *existing.first;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
auto *result = impl.allocator.Allocate<MemRefType>();
|
||||
auto *result = impl.allocator.Allocate<MemRefTypeStorage>();
|
||||
|
||||
// Copy the shape into the bump pointer.
|
||||
shape = impl.copyInto(shape);
|
||||
|
@ -755,8 +763,13 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
|
|||
impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace,
|
||||
context);
|
||||
new (result) MemRefTypeStorage{
|
||||
{Kind::MemRef, context, static_cast<unsigned int>(shape.size())},
|
||||
elementType,
|
||||
shape.data(),
|
||||
static_cast<unsigned int>(affineMapComposition.size()),
|
||||
affineMapComposition.data(),
|
||||
memorySpace};
|
||||
// Cache and return it.
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
@ -895,7 +908,7 @@ IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
|
|||
return result;
|
||||
}
|
||||
|
||||
TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
|
||||
TypeAttr TypeAttr::get(Type type, MLIRContext *context) {
|
||||
auto *&result = context->getImpl().typeAttrs[type];
|
||||
if (result)
|
||||
return result;
|
||||
|
@ -1009,9 +1022,9 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
|||
return *existing.first = result;
|
||||
}
|
||||
|
||||
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
|
||||
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
|
||||
Attribute elt) {
|
||||
auto &impl = type->getContext()->getImpl();
|
||||
auto &impl = type.getContext()->getImpl();
|
||||
|
||||
// Look to see if we already have this.
|
||||
auto *&result = impl.splatElementsAttrs[{type, elt}];
|
||||
|
@ -1030,14 +1043,14 @@ SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
|
|||
return result;
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
|
||||
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
||||
ArrayRef<char> data) {
|
||||
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
|
||||
auto bitsRequired = (long)type.getBitWidth() * type.getNumElements();
|
||||
(void)bitsRequired;
|
||||
assert((bitsRequired <= data.size() * 8L) &&
|
||||
"Input data bit size should be larger than that type requires");
|
||||
|
||||
auto &impl = type->getContext()->getImpl();
|
||||
auto &impl = type.getContext()->getImpl();
|
||||
|
||||
// Look to see if this constant is already defined.
|
||||
DenseElementsAttrInfo::KeyTy key({type, data});
|
||||
|
@ -1048,8 +1061,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
|
|||
return *existing.first;
|
||||
|
||||
// Otherwise, allocate a new one, unique it and return it.
|
||||
auto *eltType = type->getElementType();
|
||||
switch (eltType->getKind()) {
|
||||
auto eltType = type.getElementType();
|
||||
switch (eltType.getKind()) {
|
||||
case Type::Kind::BF16:
|
||||
case Type::Kind::F16:
|
||||
case Type::Kind::F32:
|
||||
|
@ -1064,7 +1077,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
|
|||
return *existing.first = result;
|
||||
}
|
||||
case Type::Kind::Integer: {
|
||||
auto width = ::cast<IntegerType>(eltType)->getWidth();
|
||||
auto width = eltType.cast<IntegerType>().getWidth();
|
||||
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
|
||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||
std::uninitialized_copy(data.begin(), data.end(), copy);
|
||||
|
@ -1080,12 +1093,12 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
|
|||
}
|
||||
}
|
||||
|
||||
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
|
||||
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
|
||||
StringRef bytes) {
|
||||
assert(isValidTensorElementType(type->getElementType()) &&
|
||||
assert(isValidTensorElementType(type.getElementType()) &&
|
||||
"Input element type should be a valid tensor element type");
|
||||
|
||||
auto &impl = type->getContext()->getImpl();
|
||||
auto &impl = type.getContext()->getImpl();
|
||||
|
||||
// Look to see if this constant is already defined.
|
||||
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
|
||||
|
@ -1104,10 +1117,10 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
|
|||
return *existing.first = result;
|
||||
}
|
||||
|
||||
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
|
||||
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values) {
|
||||
auto &impl = type->getContext()->getImpl();
|
||||
auto &impl = type.getContext()->getImpl();
|
||||
|
||||
// Look to see if we already have this.
|
||||
auto key = std::make_tuple(type, indices, values);
|
||||
|
|
|
@ -377,7 +377,7 @@ bool OpTrait::impl::verifyAtLeastNResults(const Operation *op,
|
|||
}
|
||||
|
||||
bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
|
||||
auto *type = op->getResult(0)->getType();
|
||||
auto type = op->getResult(0)->getType();
|
||||
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
|
||||
if (op->getResult(i)->getType() != type)
|
||||
return op->emitOpError(
|
||||
|
@ -393,19 +393,19 @@ bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
|
|||
|
||||
/// If this is a vector type, or a tensor type, return the scalar element type
|
||||
/// that it is built around, otherwise return the type unmodified.
|
||||
static Type *getTensorOrVectorElementType(Type *type) {
|
||||
if (auto *vec = dyn_cast<VectorType>(type))
|
||||
return vec->getElementType();
|
||||
static Type getTensorOrVectorElementType(Type type) {
|
||||
if (auto vec = type.dyn_cast<VectorType>())
|
||||
return vec.getElementType();
|
||||
|
||||
// Look through tensor<vector<...>> to find the underlying element type.
|
||||
if (auto *tensor = dyn_cast<TensorType>(type))
|
||||
return getTensorOrVectorElementType(tensor->getElementType());
|
||||
if (auto tensor = type.dyn_cast<TensorType>())
|
||||
return getTensorOrVectorElementType(tensor.getElementType());
|
||||
return type;
|
||||
}
|
||||
|
||||
bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
|
||||
for (auto *result : op->getResults()) {
|
||||
if (!isa<FloatType>(getTensorOrVectorElementType(result->getType())))
|
||||
if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
|
||||
return op->emitOpError("requires a floating point type");
|
||||
}
|
||||
|
||||
|
@ -414,7 +414,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
|
|||
|
||||
bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
|
||||
for (auto *result : op->getResults()) {
|
||||
if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType())))
|
||||
if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>())
|
||||
return op->emitOpError("requires an integer type");
|
||||
}
|
||||
return false;
|
||||
|
@ -436,7 +436,7 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result,
|
|||
|
||||
bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> ops;
|
||||
Type *type;
|
||||
Type type;
|
||||
return parser->parseOperandList(ops, 2) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(type) ||
|
||||
|
@ -448,7 +448,7 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
|
|||
*p << op->getName() << ' ' << *op->getOperand(0) << ", "
|
||||
<< *op->getOperand(1);
|
||||
p->printOptionalAttrDict(op->getAttrs());
|
||||
*p << " : " << *op->getResult(0)->getType();
|
||||
*p << " : " << op->getResult(0)->getType();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -456,14 +456,14 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void impl::buildCastOp(Builder *builder, OperationState *result,
|
||||
SSAValue *source, Type *destType) {
|
||||
SSAValue *source, Type destType) {
|
||||
result->addOperands(source);
|
||||
result->addTypes(destType);
|
||||
}
|
||||
|
||||
bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType srcInfo;
|
||||
Type *srcType, *dstType;
|
||||
Type srcType, dstType;
|
||||
return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
|
||||
parser->resolveOperand(srcInfo, srcType, result->operands) ||
|
||||
parser->parseKeywordType("to", dstType) ||
|
||||
|
@ -472,5 +472,5 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
|
|||
|
||||
void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
|
||||
*p << op->getName() << ' ' << *op->getOperand(0) << " : "
|
||||
<< *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType();
|
||||
<< op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
|
||||
}
|
||||
|
|
|
@ -239,7 +239,7 @@ void Statement::moveBefore(StmtBlock *block,
|
|||
/// Create a new OperationStmt with the specific fields.
|
||||
OperationStmt *OperationStmt::create(Location *location, OperationName name,
|
||||
ArrayRef<MLValue *> operands,
|
||||
ArrayRef<Type *> resultTypes,
|
||||
ArrayRef<Type> resultTypes,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
MLIRContext *context) {
|
||||
auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
|
||||
|
@ -288,9 +288,9 @@ MLIRContext *OperationStmt::getContext() const {
|
|||
// If we have a result or operand type, that is a constant time way to get
|
||||
// to the context.
|
||||
if (getNumResults())
|
||||
return getResult(0)->getType()->getContext();
|
||||
return getResult(0)->getType().getContext();
|
||||
if (getNumOperands())
|
||||
return getOperand(0)->getType()->getContext();
|
||||
return getOperand(0)->getType().getContext();
|
||||
|
||||
// In the very odd case where we have no operands or results, fall back to
|
||||
// doing a find.
|
||||
|
@ -474,7 +474,7 @@ MLIRContext *IfStmt::getContext() const {
|
|||
if (operands.empty())
|
||||
return findFunction()->getContext();
|
||||
|
||||
return getOperand(0)->getType()->getContext();
|
||||
return getOperand(0)->getType().getContext();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -501,7 +501,7 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
|
|||
operands.push_back(remapOperand(opValue));
|
||||
|
||||
if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
|
||||
SmallVector<Type *, 8> resultTypes;
|
||||
SmallVector<Type, 8> resultTypes;
|
||||
resultTypes.reserve(opStmt->getNumResults());
|
||||
for (auto *result : opStmt->getResults())
|
||||
resultTypes.push_back(result->getType());
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
//===- TypeDetail.h - MLIR Affine Expr storage details ----------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This holds implementation details of Type.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef TYPEDETAIL_H_
|
||||
#define TYPEDETAIL_H_
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AffineMap;
|
||||
class MLIRContext;
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Base storage class appearing in a Type.
|
||||
struct alignas(8) TypeStorage {
|
||||
TypeStorage(Type::Kind kind, MLIRContext *context)
|
||||
: context(context), kind(kind), subclassData(0) {}
|
||||
TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData)
|
||||
: context(context), kind(kind), subclassData(subclassData) {}
|
||||
|
||||
unsigned getSubclassData() const { return subclassData; }
|
||||
|
||||
void setSubclassData(unsigned val) {
|
||||
subclassData = val;
|
||||
// Ensure we don't have any accidental truncation.
|
||||
assert(getSubclassData() == val && "Subclass data too large for field");
|
||||
}
|
||||
|
||||
/// This refers to the MLIRContext in which this type was uniqued.
|
||||
MLIRContext *const context;
|
||||
|
||||
/// Classification of the subclass, used for type checking.
|
||||
Type::Kind kind : 8;
|
||||
|
||||
/// Space for subclasses to store data.
|
||||
unsigned subclassData : 24;
|
||||
};
|
||||
|
||||
struct IntegerTypeStorage : public TypeStorage {
|
||||
unsigned width;
|
||||
};
|
||||
|
||||
struct FloatTypeStorage : public TypeStorage {};
|
||||
|
||||
struct OtherTypeStorage : public TypeStorage {};
|
||||
|
||||
struct FunctionTypeStorage : public TypeStorage {
|
||||
ArrayRef<Type> getInputs() const {
|
||||
return ArrayRef<Type>(inputsAndResults, subclassData);
|
||||
}
|
||||
ArrayRef<Type> getResults() const {
|
||||
return ArrayRef<Type>(inputsAndResults + subclassData, numResults);
|
||||
}
|
||||
|
||||
unsigned numResults;
|
||||
Type const *inputsAndResults;
|
||||
};
|
||||
|
||||
struct VectorOrTensorTypeStorage : public TypeStorage {
|
||||
Type elementType;
|
||||
};
|
||||
|
||||
struct VectorTypeStorage : public VectorOrTensorTypeStorage {
|
||||
ArrayRef<int> getShape() const {
|
||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||
}
|
||||
|
||||
const int *shapeElements;
|
||||
};
|
||||
|
||||
struct TensorTypeStorage : public VectorOrTensorTypeStorage {};
|
||||
|
||||
struct RankedTensorTypeStorage : public TensorTypeStorage {
|
||||
ArrayRef<int> getShape() const {
|
||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||
}
|
||||
|
||||
const int *shapeElements;
|
||||
};
|
||||
|
||||
struct UnrankedTensorTypeStorage : public TensorTypeStorage {};
|
||||
|
||||
struct MemRefTypeStorage : public TypeStorage {
|
||||
ArrayRef<int> getShape() const {
|
||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||
}
|
||||
|
||||
ArrayRef<AffineMap> getAffineMaps() const {
|
||||
return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
|
||||
}
|
||||
|
||||
/// The type of each scalar element of the memref.
|
||||
Type elementType;
|
||||
/// An array of integers which stores the shape dimension sizes.
|
||||
const int *shapeElements;
|
||||
/// The number of affine maps in the 'affineMapList' array.
|
||||
const unsigned numAffineMaps;
|
||||
/// List of affine maps in the memref's layout/index map composition.
|
||||
AffineMap const *affineMapList;
|
||||
/// Memory space in which data referenced by memref resides.
|
||||
const unsigned memorySpace;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace mlir
|
||||
#endif // TYPEDETAIL_H_
|
|
@ -16,10 +16,17 @@
|
|||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "TypeDetail.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
Type::Kind Type::getKind() const { return type->kind; }
|
||||
|
||||
MLIRContext *Type::getContext() const { return type->context; }
|
||||
|
||||
unsigned Type::getBitWidth() const {
|
||||
switch (getKind()) {
|
||||
|
@ -32,34 +39,49 @@ unsigned Type::getBitWidth() const {
|
|||
case Type::Kind::F64:
|
||||
return 64;
|
||||
case Type::Kind::Integer:
|
||||
return cast<IntegerType>(this)->getWidth();
|
||||
return cast<IntegerType>().getWidth();
|
||||
case Type::Kind::Vector:
|
||||
case Type::Kind::RankedTensor:
|
||||
case Type::Kind::UnrankedTensor:
|
||||
return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth();
|
||||
return cast<VectorOrTensorType>().getElementType().getBitWidth();
|
||||
// TODO: Handle more types.
|
||||
default:
|
||||
llvm_unreachable("unexpected type");
|
||||
}
|
||||
}
|
||||
|
||||
IntegerType::IntegerType(unsigned width, MLIRContext *context)
|
||||
: Type(Kind::Integer, context), width(width) {
|
||||
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
|
||||
unsigned Type::getSubclassData() const { return type->getSubclassData(); }
|
||||
void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
|
||||
|
||||
IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
|
||||
|
||||
unsigned IntegerType::getWidth() const {
|
||||
return static_cast<ImplType *>(type)->width;
|
||||
}
|
||||
|
||||
FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {}
|
||||
FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
|
||||
|
||||
OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {}
|
||||
OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {}
|
||||
|
||||
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
|
||||
unsigned numResults, MLIRContext *context)
|
||||
: Type(Kind::Function, context, numInputs), numResults(numResults),
|
||||
inputsAndResults(inputsAndResults) {}
|
||||
FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {}
|
||||
|
||||
VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
|
||||
Type *elementType, unsigned subClassData)
|
||||
: Type(kind, context, subClassData), elementType(elementType) {}
|
||||
ArrayRef<Type> FunctionType::getInputs() const {
|
||||
return static_cast<ImplType *>(type)->getInputs();
|
||||
}
|
||||
|
||||
unsigned FunctionType::getNumResults() const {
|
||||
return static_cast<ImplType *>(type)->numResults;
|
||||
}
|
||||
|
||||
ArrayRef<Type> FunctionType::getResults() const {
|
||||
return static_cast<ImplType *>(type)->getResults();
|
||||
}
|
||||
|
||||
VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {}
|
||||
|
||||
Type VectorOrTensorType::getElementType() const {
|
||||
return static_cast<ImplType *>(type)->elementType;
|
||||
}
|
||||
|
||||
unsigned VectorOrTensorType::getNumElements() const {
|
||||
switch (getKind()) {
|
||||
|
@ -103,11 +125,11 @@ int VectorOrTensorType::getDimSize(unsigned i) const {
|
|||
ArrayRef<int> VectorOrTensorType::getShape() const {
|
||||
switch (getKind()) {
|
||||
case Kind::Vector:
|
||||
return cast<VectorType>(this)->getShape();
|
||||
return cast<VectorType>().getShape();
|
||||
case Kind::RankedTensor:
|
||||
return cast<RankedTensorType>(this)->getShape();
|
||||
return cast<RankedTensorType>().getShape();
|
||||
case Kind::UnrankedTensor:
|
||||
return cast<RankedTensorType>(this)->getShape();
|
||||
return cast<RankedTensorType>().getShape();
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType");
|
||||
}
|
||||
|
@ -118,35 +140,38 @@ bool VectorOrTensorType::hasStaticShape() const {
|
|||
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
|
||||
}
|
||||
|
||||
VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
|
||||
MLIRContext *context)
|
||||
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
|
||||
shapeElements(shape.data()) {}
|
||||
VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
|
||||
|
||||
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
|
||||
: VectorOrTensorType(kind, context, elementType) {
|
||||
assert(isValidTensorElementType(elementType));
|
||||
ArrayRef<int> VectorType::getShape() const {
|
||||
return static_cast<ImplType *>(type)->getShape();
|
||||
}
|
||||
|
||||
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
||||
MLIRContext *context)
|
||||
: TensorType(Kind::RankedTensor, elementType, context),
|
||||
shapeElements(shape.data()) {
|
||||
setSubclassData(shape.size());
|
||||
TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
|
||||
|
||||
RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
|
||||
|
||||
ArrayRef<int> RankedTensorType::getShape() const {
|
||||
return static_cast<ImplType *>(type)->getShape();
|
||||
}
|
||||
|
||||
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
|
||||
: TensorType(Kind::UnrankedTensor, elementType, context) {}
|
||||
UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
|
||||
|
||||
MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
|
||||
MLIRContext *context)
|
||||
: Type(Kind::MemRef, context, shape.size()), elementType(elementType),
|
||||
shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
|
||||
affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
|
||||
MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {}
|
||||
|
||||
ArrayRef<int> MemRefType::getShape() const {
|
||||
return static_cast<ImplType *>(type)->getShape();
|
||||
}
|
||||
|
||||
Type MemRefType::getElementType() const {
|
||||
return static_cast<ImplType *>(type)->elementType;
|
||||
}
|
||||
|
||||
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
|
||||
return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
|
||||
return static_cast<ImplType *>(type)->getAffineMaps();
|
||||
}
|
||||
|
||||
unsigned MemRefType::getMemorySpace() const {
|
||||
return static_cast<ImplType *>(type)->memorySpace;
|
||||
}
|
||||
|
||||
unsigned MemRefType::getNumDynamicDims() const {
|
||||
|
|
|
@ -182,19 +182,19 @@ public:
|
|||
// as the results of their action.
|
||||
|
||||
// Type parsing.
|
||||
VectorType *parseVectorType();
|
||||
VectorType parseVectorType();
|
||||
ParseResult parseXInDimensionList();
|
||||
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
|
||||
Type *parseTensorType();
|
||||
Type *parseMemRefType();
|
||||
Type *parseFunctionType();
|
||||
Type *parseType();
|
||||
ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
|
||||
ParseResult parseTypeList(SmallVectorImpl<Type *> &elements);
|
||||
Type parseTensorType();
|
||||
Type parseMemRefType();
|
||||
Type parseFunctionType();
|
||||
Type parseType();
|
||||
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
|
||||
ParseResult parseTypeList(SmallVectorImpl<Type> &elements);
|
||||
|
||||
// Attribute parsing.
|
||||
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||
FunctionType *type);
|
||||
FunctionType type);
|
||||
Attribute parseAttribute();
|
||||
|
||||
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
|
||||
|
@ -206,9 +206,9 @@ public:
|
|||
AffineMap parseAffineMapReference();
|
||||
IntegerSet parseIntegerSetInline();
|
||||
IntegerSet parseIntegerSetReference();
|
||||
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type);
|
||||
DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector);
|
||||
VectorOrTensorType *parseVectorOrTensorType();
|
||||
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type);
|
||||
DenseElementsAttr parseDenseElementsAttr(Type eltType, bool isVector);
|
||||
VectorOrTensorType parseVectorOrTensorType();
|
||||
|
||||
private:
|
||||
// The Parser is subclassed and reinstantiated. Do not add additional
|
||||
|
@ -299,7 +299,7 @@ ParseResult Parser::parseCommaSeparatedListUntil(
|
|||
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
|
||||
/// other-type ::= `index` | `tf_control`
|
||||
///
|
||||
Type *Parser::parseType() {
|
||||
Type Parser::parseType() {
|
||||
switch (getToken().getKind()) {
|
||||
default:
|
||||
return (emitError("expected type"), nullptr);
|
||||
|
@ -368,7 +368,7 @@ Type *Parser::parseType() {
|
|||
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
|
||||
/// const-dimension-list ::= (integer-literal `x`)+
|
||||
///
|
||||
VectorType *Parser::parseVectorType() {
|
||||
VectorType Parser::parseVectorType() {
|
||||
consumeToken(Token::kw_vector);
|
||||
|
||||
if (parseToken(Token::less, "expected '<' in vector type"))
|
||||
|
@ -402,11 +402,11 @@ VectorType *Parser::parseVectorType() {
|
|||
|
||||
// Parse the element type.
|
||||
auto typeLoc = getToken().getLoc();
|
||||
auto *elementType = parseType();
|
||||
auto elementType = parseType();
|
||||
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
|
||||
return nullptr;
|
||||
|
||||
if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
|
||||
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
|
||||
return (emitError(typeLoc, "invalid vector element type"), nullptr);
|
||||
|
||||
return VectorType::get(dimensions, elementType);
|
||||
|
@ -461,7 +461,7 @@ ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
|
|||
/// tensor-type ::= `tensor` `<` dimension-list element-type `>`
|
||||
/// dimension-list ::= dimension-list-ranked | `*x`
|
||||
///
|
||||
Type *Parser::parseTensorType() {
|
||||
Type Parser::parseTensorType() {
|
||||
consumeToken(Token::kw_tensor);
|
||||
|
||||
if (parseToken(Token::less, "expected '<' in tensor type"))
|
||||
|
@ -485,7 +485,7 @@ Type *Parser::parseTensorType() {
|
|||
|
||||
// Parse the element type.
|
||||
auto typeLoc = getToken().getLoc();
|
||||
auto *elementType = parseType();
|
||||
auto elementType = parseType();
|
||||
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
|
||||
return nullptr;
|
||||
|
||||
|
@ -505,7 +505,7 @@ Type *Parser::parseTensorType() {
|
|||
/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
|
||||
/// memory-space ::= integer-literal /* | TODO: address-space-id */
|
||||
///
|
||||
Type *Parser::parseMemRefType() {
|
||||
Type Parser::parseMemRefType() {
|
||||
consumeToken(Token::kw_memref);
|
||||
|
||||
if (parseToken(Token::less, "expected '<' in memref type"))
|
||||
|
@ -517,12 +517,12 @@ Type *Parser::parseMemRefType() {
|
|||
|
||||
// Parse the element type.
|
||||
auto typeLoc = getToken().getLoc();
|
||||
auto *elementType = parseType();
|
||||
auto elementType = parseType();
|
||||
if (!elementType)
|
||||
return nullptr;
|
||||
|
||||
if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
|
||||
!isa<VectorType>(elementType))
|
||||
if (!elementType.isa<IntegerType>() && !elementType.isa<FloatType>() &&
|
||||
!elementType.isa<VectorType>())
|
||||
return (emitError(typeLoc, "invalid memref element type"), nullptr);
|
||||
|
||||
// Parse semi-affine-map-composition.
|
||||
|
@ -581,10 +581,10 @@ Type *Parser::parseMemRefType() {
|
|||
///
|
||||
/// function-type ::= type-list-parens `->` type-list
|
||||
///
|
||||
Type *Parser::parseFunctionType() {
|
||||
Type Parser::parseFunctionType() {
|
||||
assert(getToken().is(Token::l_paren));
|
||||
|
||||
SmallVector<Type *, 4> arguments, results;
|
||||
SmallVector<Type, 4> arguments, results;
|
||||
if (parseTypeList(arguments) ||
|
||||
parseToken(Token::arrow, "expected '->' in function type") ||
|
||||
parseTypeList(results))
|
||||
|
@ -598,7 +598,7 @@ Type *Parser::parseFunctionType() {
|
|||
///
|
||||
/// type-list-no-parens ::= type (`,` type)*
|
||||
///
|
||||
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
|
||||
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
auto elt = parseType();
|
||||
elements.push_back(elt);
|
||||
|
@ -615,7 +615,7 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
|
|||
/// type-list-parens ::= `(` `)`
|
||||
/// | `(` type-list-no-parens `)`
|
||||
///
|
||||
ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
|
||||
ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
auto elt = parseType();
|
||||
elements.push_back(elt);
|
||||
|
@ -639,8 +639,8 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
|
|||
namespace {
|
||||
class TensorLiteralParser {
|
||||
public:
|
||||
TensorLiteralParser(Parser &p, Type *eltTy)
|
||||
: p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {}
|
||||
TensorLiteralParser(Parser &p, Type eltTy)
|
||||
: p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {}
|
||||
|
||||
ParseResult parse() { return parseList(shape); }
|
||||
|
||||
|
@ -676,7 +676,7 @@ private:
|
|||
}
|
||||
|
||||
Parser &p;
|
||||
Type *eltTy;
|
||||
Type eltTy;
|
||||
size_t currBitPos;
|
||||
size_t bitsWidth;
|
||||
SmallVector<int, 4> shape;
|
||||
|
@ -698,7 +698,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
|||
if (!result)
|
||||
return p.emitError("expected tensor element");
|
||||
// check result matches the element type.
|
||||
switch (eltTy->getKind()) {
|
||||
switch (eltTy.getKind()) {
|
||||
case Type::Kind::BF16:
|
||||
case Type::Kind::F16:
|
||||
case Type::Kind::F32:
|
||||
|
@ -779,7 +779,7 @@ ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl<int> &dims) {
|
|||
/// synthesizing a forward reference) or emit an error and return null on
|
||||
/// failure.
|
||||
Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||
FunctionType *type) {
|
||||
FunctionType type) {
|
||||
Identifier name = builder.getIdentifier(nameStr.drop_front());
|
||||
|
||||
// See if the function has already been defined in the module.
|
||||
|
@ -902,10 +902,10 @@ Attribute Parser::parseAttribute() {
|
|||
if (parseToken(Token::colon, "expected ':' and function type"))
|
||||
return nullptr;
|
||||
auto typeLoc = getToken().getLoc();
|
||||
Type *type = parseType();
|
||||
Type type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
auto *fnType = dyn_cast<FunctionType>(type);
|
||||
auto fnType = type.dyn_cast<FunctionType>();
|
||||
if (!fnType)
|
||||
return (emitError(typeLoc, "expected function type"), nullptr);
|
||||
|
||||
|
@ -916,7 +916,7 @@ Attribute Parser::parseAttribute() {
|
|||
consumeToken(Token::kw_opaque);
|
||||
if (parseToken(Token::less, "expected '<' after 'opaque'"))
|
||||
return nullptr;
|
||||
auto *type = parseVectorOrTensorType();
|
||||
auto type = parseVectorOrTensorType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
auto val = getToken().getStringValue();
|
||||
|
@ -937,7 +937,7 @@ Attribute Parser::parseAttribute() {
|
|||
if (parseToken(Token::less, "expected '<' after 'splat'"))
|
||||
return nullptr;
|
||||
|
||||
auto *type = parseVectorOrTensorType();
|
||||
auto type = parseVectorOrTensorType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
switch (getToken().getKind()) {
|
||||
|
@ -959,7 +959,7 @@ Attribute Parser::parseAttribute() {
|
|||
if (parseToken(Token::less, "expected '<' after 'dense'"))
|
||||
return nullptr;
|
||||
|
||||
auto *type = parseVectorOrTensorType();
|
||||
auto type = parseVectorOrTensorType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
|
@ -981,41 +981,41 @@ Attribute Parser::parseAttribute() {
|
|||
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
|
||||
return nullptr;
|
||||
|
||||
auto *type = parseVectorOrTensorType();
|
||||
auto type = parseVectorOrTensorType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
switch (getToken().getKind()) {
|
||||
case Token::l_square: {
|
||||
/// Parse indices
|
||||
auto *indicesEltType = builder.getIntegerType(32);
|
||||
auto indicesEltType = builder.getIntegerType(32);
|
||||
auto indices =
|
||||
parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
|
||||
parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
|
||||
|
||||
if (parseToken(Token::comma, "expected ','"))
|
||||
return nullptr;
|
||||
|
||||
/// Parse values.
|
||||
auto *valuesEltType = type->getElementType();
|
||||
auto valuesEltType = type.getElementType();
|
||||
auto values =
|
||||
parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
|
||||
parseDenseElementsAttr(valuesEltType, type.isa<VectorType>());
|
||||
|
||||
/// Sanity check.
|
||||
auto *indicesType = indices.getType();
|
||||
auto *valuesType = values.getType();
|
||||
auto sameShape = (indicesType->getRank() == 1) ||
|
||||
(type->getRank() == indicesType->getDimSize(1));
|
||||
auto indicesType = indices.getType();
|
||||
auto valuesType = values.getType();
|
||||
auto sameShape = (indicesType.getRank() == 1) ||
|
||||
(type.getRank() == indicesType.getDimSize(1));
|
||||
auto sameElementNum =
|
||||
indicesType->getDimSize(0) == valuesType->getDimSize(0);
|
||||
indicesType.getDimSize(0) == valuesType.getDimSize(0);
|
||||
if (!sameShape || !sameElementNum) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream s(str);
|
||||
s << "expected shape ([";
|
||||
interleaveComma(type->getShape(), s);
|
||||
interleaveComma(type.getShape(), s);
|
||||
s << "]); inferred shape of indices literal ([";
|
||||
interleaveComma(indicesType->getShape(), s);
|
||||
interleaveComma(indicesType.getShape(), s);
|
||||
s << "]); inferred shape of values literal ([";
|
||||
interleaveComma(valuesType->getShape(), s);
|
||||
interleaveComma(valuesType.getShape(), s);
|
||||
s << "])";
|
||||
return (emitError(s.str()), nullptr);
|
||||
}
|
||||
|
@ -1035,7 +1035,7 @@ Attribute Parser::parseAttribute() {
|
|||
nullptr);
|
||||
}
|
||||
default: {
|
||||
if (Type *type = parseType())
|
||||
if (Type type = parseType())
|
||||
return builder.getTypeAttr(type);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1051,12 +1051,12 @@ Attribute Parser::parseAttribute() {
|
|||
///
|
||||
/// This method returns a constructed dense elements attribute with the shape
|
||||
/// from the parsing result.
|
||||
DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
|
||||
DenseElementsAttr Parser::parseDenseElementsAttr(Type eltType, bool isVector) {
|
||||
TensorLiteralParser literalParser(*this, eltType);
|
||||
if (literalParser.parse())
|
||||
return nullptr;
|
||||
|
||||
VectorOrTensorType *type;
|
||||
VectorOrTensorType type;
|
||||
if (isVector) {
|
||||
type = builder.getVectorType(literalParser.getShape(), eltType);
|
||||
} else {
|
||||
|
@ -1076,18 +1076,18 @@ DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
|
|||
/// This method compares the shapes from the parsing result and that from the
|
||||
/// input argument. It returns a constructed dense elements attribute if both
|
||||
/// match.
|
||||
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
|
||||
auto *eltTy = type->getElementType();
|
||||
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
|
||||
auto eltTy = type.getElementType();
|
||||
TensorLiteralParser literalParser(*this, eltTy);
|
||||
if (literalParser.parse())
|
||||
return nullptr;
|
||||
if (literalParser.getShape() != type->getShape()) {
|
||||
if (literalParser.getShape() != type.getShape()) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream s(str);
|
||||
s << "inferred shape of elements literal ([";
|
||||
interleaveComma(literalParser.getShape(), s);
|
||||
s << "]) does not match type ([";
|
||||
interleaveComma(type->getShape(), s);
|
||||
interleaveComma(type.getShape(), s);
|
||||
s << "])";
|
||||
return (emitError(s.str()), nullptr);
|
||||
}
|
||||
|
@ -1100,8 +1100,8 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
|
|||
/// vector-or-tensor-type ::= vector-type | tensor-type
|
||||
///
|
||||
/// This method also checks the type has static shape and ranked.
|
||||
VectorOrTensorType *Parser::parseVectorOrTensorType() {
|
||||
auto *type = dyn_cast<VectorOrTensorType>(parseType());
|
||||
VectorOrTensorType Parser::parseVectorOrTensorType() {
|
||||
auto type = parseType().dyn_cast<VectorOrTensorType>();
|
||||
if (!type) {
|
||||
return (emitError("expected elements literal has a tensor or vector type"),
|
||||
nullptr);
|
||||
|
@ -1110,7 +1110,7 @@ VectorOrTensorType *Parser::parseVectorOrTensorType() {
|
|||
if (parseToken(Token::comma, "expected ','"))
|
||||
return nullptr;
|
||||
|
||||
if (!type->hasStaticShape() || type->getRank() == -1) {
|
||||
if (!type.hasStaticShape() || type.getRank() == -1) {
|
||||
return (emitError("tensor literals must be ranked and have static shape"),
|
||||
nullptr);
|
||||
}
|
||||
|
@ -1834,7 +1834,7 @@ public:
|
|||
|
||||
/// Given a reference to an SSA value and its type, return a reference. This
|
||||
/// returns null on failure.
|
||||
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type);
|
||||
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type);
|
||||
|
||||
/// Register a definition of a value with the symbol table.
|
||||
ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
|
||||
|
@ -1845,11 +1845,11 @@ public:
|
|||
|
||||
template <typename ResultType>
|
||||
ResultType parseSSADefOrUseAndType(
|
||||
const std::function<ResultType(SSAUseInfo, Type *)> &action);
|
||||
const std::function<ResultType(SSAUseInfo, Type)> &action);
|
||||
|
||||
SSAValue *parseSSAUseAndType() {
|
||||
return parseSSADefOrUseAndType<SSAValue *>(
|
||||
[&](SSAUseInfo useInfo, Type *type) -> SSAValue * {
|
||||
[&](SSAUseInfo useInfo, Type type) -> SSAValue * {
|
||||
return resolveSSAUse(useInfo, type);
|
||||
});
|
||||
}
|
||||
|
@ -1880,7 +1880,7 @@ private:
|
|||
/// their first reference, to allow checking for use of undefined values.
|
||||
DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
|
||||
|
||||
SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type *type);
|
||||
SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type);
|
||||
|
||||
/// Return true if this is a forward reference.
|
||||
bool isForwardReferencePlaceholder(SSAValue *value) {
|
||||
|
@ -1891,7 +1891,7 @@ private:
|
|||
|
||||
/// Create and remember a new placeholder for a forward reference.
|
||||
SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
|
||||
Type *type) {
|
||||
Type type) {
|
||||
// Forward references are always created as instructions, even in ML
|
||||
// functions, because we just need something with a def/use chain.
|
||||
//
|
||||
|
@ -1908,7 +1908,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
|
|||
|
||||
/// Given an unbound reference to an SSA value and its type, return the value
|
||||
/// it specifies. This returns null on failure.
|
||||
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) {
|
||||
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
|
||||
auto &entries = values[useInfo.name];
|
||||
|
||||
// If we have already seen a value of this name, return it.
|
||||
|
@ -2057,14 +2057,14 @@ FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
|
|||
/// ssa-use-and-type ::= ssa-use `:` type
|
||||
template <typename ResultType>
|
||||
ResultType FunctionParser::parseSSADefOrUseAndType(
|
||||
const std::function<ResultType(SSAUseInfo, Type *)> &action) {
|
||||
const std::function<ResultType(SSAUseInfo, Type)> &action) {
|
||||
|
||||
SSAUseInfo useInfo;
|
||||
if (parseSSAUse(useInfo) ||
|
||||
parseToken(Token::colon, "expected ':' and type for SSA operand"))
|
||||
return nullptr;
|
||||
|
||||
auto *type = parseType();
|
||||
auto type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
|
@ -2101,7 +2101,7 @@ ParseResult FunctionParser::parseOptionalSSAUseAndTypeList(
|
|||
if (valueIDs.empty())
|
||||
return ParseSuccess;
|
||||
|
||||
SmallVector<Type *, 4> types;
|
||||
SmallVector<Type, 4> types;
|
||||
if (parseToken(Token::colon, "expected ':' in operand list") ||
|
||||
parseTypeListNoParens(types))
|
||||
return ParseFailure;
|
||||
|
@ -2209,14 +2209,14 @@ Operation *FunctionParser::parseVerboseOperation(
|
|||
auto type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
auto fnType = dyn_cast<FunctionType>(type);
|
||||
auto fnType = type.dyn_cast<FunctionType>();
|
||||
if (!fnType)
|
||||
return (emitError(typeLoc, "expected function type"), nullptr);
|
||||
|
||||
result.addTypes(fnType->getResults());
|
||||
result.addTypes(fnType.getResults());
|
||||
|
||||
// Check that we have the right number of types for the operands.
|
||||
auto operandTypes = fnType->getInputs();
|
||||
auto operandTypes = fnType.getInputs();
|
||||
if (operandTypes.size() != operandInfos.size()) {
|
||||
auto plural = "s"[operandInfos.size() == 1];
|
||||
return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
|
||||
|
@ -2253,17 +2253,17 @@ public:
|
|||
return parser.parseToken(Token::comma, "expected ','");
|
||||
}
|
||||
|
||||
bool parseColonType(Type *&result) override {
|
||||
bool parseColonType(Type &result) override {
|
||||
return parser.parseToken(Token::colon, "expected ':'") ||
|
||||
!(result = parser.parseType());
|
||||
}
|
||||
|
||||
bool parseColonTypeList(SmallVectorImpl<Type *> &result) override {
|
||||
bool parseColonTypeList(SmallVectorImpl<Type> &result) override {
|
||||
if (parser.parseToken(Token::colon, "expected ':'"))
|
||||
return true;
|
||||
|
||||
do {
|
||||
if (auto *type = parser.parseType())
|
||||
if (auto type = parser.parseType())
|
||||
result.push_back(type);
|
||||
else
|
||||
return true;
|
||||
|
@ -2273,7 +2273,7 @@ public:
|
|||
}
|
||||
|
||||
/// Parse a keyword followed by a type.
|
||||
bool parseKeywordType(const char *keyword, Type *&result) override {
|
||||
bool parseKeywordType(const char *keyword, Type &result) override {
|
||||
if (parser.getTokenSpelling() != keyword)
|
||||
return parser.emitError("expected '" + Twine(keyword) + "'");
|
||||
parser.consumeToken();
|
||||
|
@ -2396,7 +2396,7 @@ public:
|
|||
}
|
||||
|
||||
/// Resolve a parse function name and a type into a function reference.
|
||||
virtual bool resolveFunctionName(StringRef name, FunctionType *type,
|
||||
virtual bool resolveFunctionName(StringRef name, FunctionType type,
|
||||
llvm::SMLoc loc, Function *&result) {
|
||||
result = parser.resolveFunctionReference(name, loc, type);
|
||||
return result == nullptr;
|
||||
|
@ -2410,7 +2410,7 @@ public:
|
|||
|
||||
llvm::SMLoc getNameLoc() const override { return nameLoc; }
|
||||
|
||||
bool resolveOperand(const OperandType &operand, Type *type,
|
||||
bool resolveOperand(const OperandType &operand, Type type,
|
||||
SmallVectorImpl<SSAValue *> &result) override {
|
||||
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
|
||||
operand.location};
|
||||
|
@ -2559,11 +2559,11 @@ ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList(
|
|||
return ParseSuccess;
|
||||
|
||||
return parseCommaSeparatedList([&]() -> ParseResult {
|
||||
auto type = parseSSADefOrUseAndType<Type *>(
|
||||
[&](SSAUseInfo useInfo, Type *type) -> Type * {
|
||||
auto type = parseSSADefOrUseAndType<Type>(
|
||||
[&](SSAUseInfo useInfo, Type type) -> Type {
|
||||
BBArgument *arg = owner->addArgument(type);
|
||||
if (addDefinition(useInfo, arg))
|
||||
return nullptr;
|
||||
return {};
|
||||
return type;
|
||||
});
|
||||
return type ? ParseSuccess : ParseFailure;
|
||||
|
@ -2908,7 +2908,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
|
|||
" symbol count must match");
|
||||
|
||||
// Resolve SSA uses.
|
||||
Type *indexType = builder.getIndexType();
|
||||
Type indexType = builder.getIndexType();
|
||||
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
|
||||
SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
|
||||
if (!sval)
|
||||
|
@ -3187,9 +3187,9 @@ private:
|
|||
ParseResult parseAffineStructureDef();
|
||||
|
||||
// Functions.
|
||||
ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
|
||||
ParseResult parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
|
||||
SmallVectorImpl<StringRef> &argNames);
|
||||
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
|
||||
ParseResult parseFunctionSignature(StringRef &name, FunctionType &type,
|
||||
SmallVectorImpl<StringRef> *argNames);
|
||||
ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs);
|
||||
ParseResult parseExtFunc();
|
||||
|
@ -3248,7 +3248,7 @@ ParseResult ModuleParser::parseAffineStructureDef() {
|
|||
/// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
|
||||
///
|
||||
ParseResult
|
||||
ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
|
||||
ModuleParser::parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
|
||||
SmallVectorImpl<StringRef> &argNames) {
|
||||
consumeToken(Token::l_paren);
|
||||
|
||||
|
@ -3284,7 +3284,7 @@ ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
|
|||
/// type-list)?
|
||||
///
|
||||
ParseResult
|
||||
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
|
||||
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
|
||||
SmallVectorImpl<StringRef> *argNames) {
|
||||
if (getToken().isNot(Token::at_identifier))
|
||||
return emitError("expected a function identifier like '@foo'");
|
||||
|
@ -3295,7 +3295,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
|
|||
if (getToken().isNot(Token::l_paren))
|
||||
return emitError("expected '(' in function signature");
|
||||
|
||||
SmallVector<Type *, 4> argTypes;
|
||||
SmallVector<Type, 4> argTypes;
|
||||
ParseResult parseResult;
|
||||
|
||||
if (argNames)
|
||||
|
@ -3307,7 +3307,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
|
|||
return ParseFailure;
|
||||
|
||||
// Parse the return type if present.
|
||||
SmallVector<Type *, 4> results;
|
||||
SmallVector<Type, 4> results;
|
||||
if (consumeIf(Token::arrow)) {
|
||||
if (parseTypeList(results))
|
||||
return ParseFailure;
|
||||
|
@ -3340,7 +3340,7 @@ ParseResult ModuleParser::parseExtFunc() {
|
|||
auto loc = getToken().getLoc();
|
||||
|
||||
StringRef name;
|
||||
FunctionType *type = nullptr;
|
||||
FunctionType type;
|
||||
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
|
||||
return ParseFailure;
|
||||
|
||||
|
@ -3372,7 +3372,7 @@ ParseResult ModuleParser::parseCFGFunc() {
|
|||
auto loc = getToken().getLoc();
|
||||
|
||||
StringRef name;
|
||||
FunctionType *type = nullptr;
|
||||
FunctionType type;
|
||||
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
|
||||
return ParseFailure;
|
||||
|
||||
|
@ -3405,7 +3405,7 @@ ParseResult ModuleParser::parseMLFunc() {
|
|||
consumeToken(Token::kw_mlfunc);
|
||||
|
||||
StringRef name;
|
||||
FunctionType *type = nullptr;
|
||||
FunctionType type;
|
||||
SmallVector<StringRef, 4> argNames;
|
||||
|
||||
auto loc = getToken().getLoc();
|
||||
|
|
|
@ -138,23 +138,23 @@ void AddIOp::getCanonicalizationPatterns(OwningPatternList &results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AllocOp::build(Builder *builder, OperationState *result,
|
||||
MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
|
||||
MemRefType memrefType, ArrayRef<SSAValue *> operands) {
|
||||
result->addOperands(operands);
|
||||
result->types.push_back(memrefType);
|
||||
}
|
||||
|
||||
void AllocOp::print(OpAsmPrinter *p) const {
|
||||
MemRefType *type = getType();
|
||||
MemRefType type = getType();
|
||||
*p << "alloc";
|
||||
// Print dynamic dimension operands.
|
||||
printDimAndSymbolList(operand_begin(), operand_end(),
|
||||
type->getNumDynamicDims(), p);
|
||||
type.getNumDynamicDims(), p);
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
|
||||
*p << " : " << *type;
|
||||
*p << " : " << type;
|
||||
}
|
||||
|
||||
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
MemRefType *type;
|
||||
MemRefType type;
|
||||
|
||||
// Parse the dimension operands and optional symbol operands, followed by a
|
||||
// memref type.
|
||||
|
@ -170,7 +170,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
// Verification still checks that the total number of operands matches
|
||||
// the number of symbols in the affine map, plus the number of dynamic
|
||||
// dimensions in the memref.
|
||||
if (numDimOperands != type->getNumDynamicDims()) {
|
||||
if (numDimOperands != type.getNumDynamicDims()) {
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"dimension operand count does not equal memref "
|
||||
"dynamic dimension count");
|
||||
|
@ -180,13 +180,13 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
}
|
||||
|
||||
bool AllocOp::verify() const {
|
||||
auto *memRefType = dyn_cast<MemRefType>(getResult()->getType());
|
||||
auto memRefType = getResult()->getType().dyn_cast<MemRefType>();
|
||||
if (!memRefType)
|
||||
return emitOpError("result must be a memref");
|
||||
|
||||
unsigned numSymbols = 0;
|
||||
if (!memRefType->getAffineMaps().empty()) {
|
||||
AffineMap affineMap = memRefType->getAffineMaps()[0];
|
||||
if (!memRefType.getAffineMaps().empty()) {
|
||||
AffineMap affineMap = memRefType.getAffineMaps()[0];
|
||||
// Store number of symbols used in affine map (used in subsequent check).
|
||||
numSymbols = affineMap.getNumSymbols();
|
||||
// TODO(zinenko): this check does not belong to AllocOp, or any other op but
|
||||
|
@ -195,10 +195,10 @@ bool AllocOp::verify() const {
|
|||
// Remove when we can emit errors directly from *Type::get(...) functions.
|
||||
//
|
||||
// Verify that the layout affine map matches the rank of the memref.
|
||||
if (affineMap.getNumDims() != memRefType->getRank())
|
||||
if (affineMap.getNumDims() != memRefType.getRank())
|
||||
return emitOpError("affine map dimension count must equal memref rank");
|
||||
}
|
||||
unsigned numDynamicDims = memRefType->getNumDynamicDims();
|
||||
unsigned numDynamicDims = memRefType.getNumDynamicDims();
|
||||
// Check that the total number of operands matches the number of symbols in
|
||||
// the affine map, plus the number of dynamic dimensions specified in the
|
||||
// memref type.
|
||||
|
@ -208,7 +208,7 @@ bool AllocOp::verify() const {
|
|||
}
|
||||
// Verify that all operands are of type Index.
|
||||
for (auto *operand : getOperands()) {
|
||||
if (!operand->getType()->isIndex())
|
||||
if (!operand->getType().isIndex())
|
||||
return emitOpError("requires operands to be of type Index");
|
||||
}
|
||||
return false;
|
||||
|
@ -239,13 +239,13 @@ struct SimplifyAllocConst : public Pattern {
|
|||
// Ok, we have one or more constant operands. Collect the non-constant ones
|
||||
// and keep track of the resultant memref type to build.
|
||||
SmallVector<int, 4> newShapeConstants;
|
||||
newShapeConstants.reserve(memrefType->getRank());
|
||||
newShapeConstants.reserve(memrefType.getRank());
|
||||
SmallVector<SSAValue *, 4> newOperands;
|
||||
SmallVector<SSAValue *, 4> droppedOperands;
|
||||
|
||||
unsigned dynamicDimPos = 0;
|
||||
for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) {
|
||||
int dimSize = memrefType->getDimSize(dim);
|
||||
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
|
||||
int dimSize = memrefType.getDimSize(dim);
|
||||
// If this is already static dimension, keep it.
|
||||
if (dimSize != -1) {
|
||||
newShapeConstants.push_back(dimSize);
|
||||
|
@ -267,10 +267,10 @@ struct SimplifyAllocConst : public Pattern {
|
|||
}
|
||||
|
||||
// Create new memref type (which will have fewer dynamic dimensions).
|
||||
auto *newMemRefType = MemRefType::get(
|
||||
newShapeConstants, memrefType->getElementType(),
|
||||
memrefType->getAffineMaps(), memrefType->getMemorySpace());
|
||||
assert(newOperands.size() == newMemRefType->getNumDynamicDims());
|
||||
auto newMemRefType = MemRefType::get(
|
||||
newShapeConstants, memrefType.getElementType(),
|
||||
memrefType.getAffineMaps(), memrefType.getMemorySpace());
|
||||
assert(newOperands.size() == newMemRefType.getNumDynamicDims());
|
||||
|
||||
// Create and insert the alloc op for the new memref.
|
||||
auto newAlloc =
|
||||
|
@ -297,13 +297,13 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee,
|
|||
ArrayRef<SSAValue *> operands) {
|
||||
result->addOperands(operands);
|
||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||
result->addTypes(callee->getType()->getResults());
|
||||
result->addTypes(callee->getType().getResults());
|
||||
}
|
||||
|
||||
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
StringRef calleeName;
|
||||
llvm::SMLoc calleeLoc;
|
||||
FunctionType *calleeType = nullptr;
|
||||
FunctionType calleeType;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||
Function *callee = nullptr;
|
||||
if (parser->parseFunctionName(calleeName, calleeLoc) ||
|
||||
|
@ -312,8 +312,8 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(calleeType) ||
|
||||
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
|
||||
parser->addTypesToList(calleeType->getResults(), result->types) ||
|
||||
parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
|
||||
parser->addTypesToList(calleeType.getResults(), result->types) ||
|
||||
parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
|
||||
result->operands))
|
||||
return true;
|
||||
|
||||
|
@ -328,7 +328,7 @@ void CallOp::print(OpAsmPrinter *p) const {
|
|||
p->printOperands(getOperands());
|
||||
*p << ')';
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
|
||||
*p << " : " << *getCallee()->getType();
|
||||
*p << " : " << getCallee()->getType();
|
||||
}
|
||||
|
||||
bool CallOp::verify() const {
|
||||
|
@ -338,20 +338,20 @@ bool CallOp::verify() const {
|
|||
return emitOpError("requires a 'callee' function attribute");
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
auto *fnType = fnAttr.getValue()->getType();
|
||||
if (fnType->getNumInputs() != getNumOperands())
|
||||
auto fnType = fnAttr.getValue()->getType();
|
||||
if (fnType.getNumInputs() != getNumOperands())
|
||||
return emitOpError("incorrect number of operands for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
|
||||
if (getOperand(i)->getType() != fnType->getInput(i))
|
||||
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
|
||||
if (getOperand(i)->getType() != fnType.getInput(i))
|
||||
return emitOpError("operand type mismatch");
|
||||
}
|
||||
|
||||
if (fnType->getNumResults() != getNumResults())
|
||||
if (fnType.getNumResults() != getNumResults())
|
||||
return emitOpError("incorrect number of results for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
|
||||
if (getResult(i)->getType() != fnType->getResult(i))
|
||||
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
|
||||
if (getResult(i)->getType() != fnType.getResult(i))
|
||||
return emitOpError("result type mismatch");
|
||||
}
|
||||
|
||||
|
@ -364,14 +364,14 @@ bool CallOp::verify() const {
|
|||
|
||||
void CallIndirectOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *callee, ArrayRef<SSAValue *> operands) {
|
||||
auto *fnType = cast<FunctionType>(callee->getType());
|
||||
auto fnType = callee->getType().cast<FunctionType>();
|
||||
result->operands.push_back(callee);
|
||||
result->addOperands(operands);
|
||||
result->addTypes(fnType->getResults());
|
||||
result->addTypes(fnType.getResults());
|
||||
}
|
||||
|
||||
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
FunctionType *calleeType = nullptr;
|
||||
FunctionType calleeType;
|
||||
OpAsmParser::OperandType callee;
|
||||
llvm::SMLoc operandsLoc;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||
|
@ -382,9 +382,9 @@ bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(calleeType) ||
|
||||
parser->resolveOperand(callee, calleeType, result->operands) ||
|
||||
parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
|
||||
parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
|
||||
result->operands) ||
|
||||
parser->addTypesToList(calleeType->getResults(), result->types);
|
||||
parser->addTypesToList(calleeType.getResults(), result->types);
|
||||
}
|
||||
|
||||
void CallIndirectOp::print(OpAsmPrinter *p) const {
|
||||
|
@ -395,29 +395,29 @@ void CallIndirectOp::print(OpAsmPrinter *p) const {
|
|||
p->printOperands(++operandRange.begin(), operandRange.end());
|
||||
*p << ')';
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
|
||||
*p << " : " << *getCallee()->getType();
|
||||
*p << " : " << getCallee()->getType();
|
||||
}
|
||||
|
||||
bool CallIndirectOp::verify() const {
|
||||
// The callee must be a function.
|
||||
auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
|
||||
auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
|
||||
if (!fnType)
|
||||
return emitOpError("callee must have function type");
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
if (fnType->getNumInputs() != getNumOperands() - 1)
|
||||
if (fnType.getNumInputs() != getNumOperands() - 1)
|
||||
return emitOpError("incorrect number of operands for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
|
||||
if (getOperand(i + 1)->getType() != fnType->getInput(i))
|
||||
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
|
||||
if (getOperand(i + 1)->getType() != fnType.getInput(i))
|
||||
return emitOpError("operand type mismatch");
|
||||
}
|
||||
|
||||
if (fnType->getNumResults() != getNumResults())
|
||||
if (fnType.getNumResults() != getNumResults())
|
||||
return emitOpError("incorrect number of results for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
|
||||
if (getResult(i)->getType() != fnType->getResult(i))
|
||||
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
|
||||
if (getResult(i)->getType() != fnType.getResult(i))
|
||||
return emitOpError("result type mismatch");
|
||||
}
|
||||
|
||||
|
@ -434,19 +434,19 @@ void DeallocOp::build(Builder *builder, OperationState *result,
|
|||
}
|
||||
|
||||
void DeallocOp::print(OpAsmPrinter *p) const {
|
||||
*p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
|
||||
*p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType();
|
||||
}
|
||||
|
||||
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType memrefInfo;
|
||||
MemRefType *type;
|
||||
MemRefType type;
|
||||
|
||||
return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
|
||||
parser->resolveOperand(memrefInfo, type, result->operands);
|
||||
}
|
||||
|
||||
bool DeallocOp::verify() const {
|
||||
if (!isa<MemRefType>(getMemRef()->getType()))
|
||||
if (!getMemRef()->getType().isa<MemRefType>())
|
||||
return emitOpError("operand must be a memref");
|
||||
return false;
|
||||
}
|
||||
|
@ -472,13 +472,13 @@ void DimOp::build(Builder *builder, OperationState *result,
|
|||
void DimOp::print(OpAsmPrinter *p) const {
|
||||
*p << "dim " << *getOperand() << ", " << getIndex();
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
|
||||
*p << " : " << *getOperand()->getType();
|
||||
*p << " : " << getOperand()->getType();
|
||||
}
|
||||
|
||||
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType operandInfo;
|
||||
IntegerAttr indexAttr;
|
||||
Type *type;
|
||||
Type type;
|
||||
|
||||
return parser->parseOperand(operandInfo) || parser->parseComma() ||
|
||||
parser->parseAttribute(indexAttr, "index", result->attributes) ||
|
||||
|
@ -496,15 +496,15 @@ bool DimOp::verify() const {
|
|||
return emitOpError("requires an integer attribute named 'index'");
|
||||
uint64_t index = (uint64_t)indexAttr.getValue();
|
||||
|
||||
auto *type = getOperand()->getType();
|
||||
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
|
||||
if (index >= tensorType->getRank())
|
||||
auto type = getOperand()->getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
if (index >= tensorType.getRank())
|
||||
return emitOpError("index is out of range");
|
||||
} else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
|
||||
if (index >= memrefType->getRank())
|
||||
} else if (auto memrefType = type.dyn_cast<MemRefType>()) {
|
||||
if (index >= memrefType.getRank())
|
||||
return emitOpError("index is out of range");
|
||||
|
||||
} else if (isa<UnrankedTensorType>(type)) {
|
||||
} else if (type.isa<UnrankedTensorType>()) {
|
||||
// ok, assumed to be in-range.
|
||||
} else {
|
||||
return emitOpError("requires an operand with tensor or memref type");
|
||||
|
@ -516,12 +516,12 @@ bool DimOp::verify() const {
|
|||
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
// Constant fold dim when the size along the index referred to is a constant.
|
||||
auto *opType = getOperand()->getType();
|
||||
auto opType = getOperand()->getType();
|
||||
int indexSize = -1;
|
||||
if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) {
|
||||
indexSize = tensorType->getShape()[getIndex()];
|
||||
} else if (auto *memrefType = dyn_cast<MemRefType>(opType)) {
|
||||
indexSize = memrefType->getShape()[getIndex()];
|
||||
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||
indexSize = tensorType.getShape()[getIndex()];
|
||||
} else if (auto memrefType = opType.dyn_cast<MemRefType>()) {
|
||||
indexSize = memrefType.getShape()[getIndex()];
|
||||
}
|
||||
|
||||
if (indexSize >= 0)
|
||||
|
@ -544,9 +544,9 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
|
|||
p->printOperands(getTagIndices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
*p << " : " << *getSrcMemRef()->getType();
|
||||
*p << ", " << *getDstMemRef()->getType();
|
||||
*p << ", " << *getTagMemRef()->getType();
|
||||
*p << " : " << getSrcMemRef()->getType();
|
||||
*p << ", " << getDstMemRef()->getType();
|
||||
*p << ", " << getTagMemRef()->getType();
|
||||
}
|
||||
|
||||
// Parse DmaStartOp.
|
||||
|
@ -566,8 +566,8 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
OpAsmParser::OperandType tagMemrefInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
|
||||
|
||||
SmallVector<Type *, 3> types;
|
||||
auto *indexType = parser->getBuilder().getIndexType();
|
||||
SmallVector<Type, 3> types;
|
||||
auto indexType = parser->getBuilder().getIndexType();
|
||||
|
||||
// Parse and resolve the following list of operands:
|
||||
// *) source memref followed by its indices (in square brackets).
|
||||
|
@ -601,12 +601,12 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
return true;
|
||||
|
||||
// Check that source/destination index list size matches associated rank.
|
||||
if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() ||
|
||||
dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank())
|
||||
if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
|
||||
dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"memref rank not equal to indices count");
|
||||
|
||||
if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank())
|
||||
if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"tag memref rank not equal to indices count");
|
||||
|
||||
|
@ -632,7 +632,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
|
|||
p->printOperands(getTagIndices());
|
||||
*p << "], ";
|
||||
p->printOperand(getNumElements());
|
||||
*p << " : " << *getTagMemRef()->getType();
|
||||
*p << " : " << getTagMemRef()->getType();
|
||||
}
|
||||
|
||||
// Parse DmaWaitOp.
|
||||
|
@ -642,8 +642,8 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
|
|||
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType tagMemrefInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
|
||||
Type *type;
|
||||
auto *indexType = parser->getBuilder().getIndexType();
|
||||
Type type;
|
||||
auto indexType = parser->getBuilder().getIndexType();
|
||||
OpAsmParser::OperandType numElementsInfo;
|
||||
|
||||
// Parse tag memref, its indices, and dma size.
|
||||
|
@ -657,7 +657,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
parser->resolveOperand(numElementsInfo, indexType, result->operands))
|
||||
return true;
|
||||
|
||||
if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
|
||||
if (tagIndexInfos.size() != type.cast<MemRefType>().getRank())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"tag memref rank not equal to indices count");
|
||||
|
||||
|
@ -678,10 +678,10 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningPatternList &results,
|
|||
void ExtractElementOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *aggregate,
|
||||
ArrayRef<SSAValue *> indices) {
|
||||
auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
|
||||
auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
|
||||
result->addOperands(aggregate);
|
||||
result->addOperands(indices);
|
||||
result->types.push_back(aggregateType->getElementType());
|
||||
result->types.push_back(aggregateType.getElementType());
|
||||
}
|
||||
|
||||
void ExtractElementOp::print(OpAsmPrinter *p) const {
|
||||
|
@ -689,13 +689,13 @@ void ExtractElementOp::print(OpAsmPrinter *p) const {
|
|||
p->printOperands(getIndices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
*p << " : " << *getAggregate()->getType();
|
||||
*p << " : " << getAggregate()->getType();
|
||||
}
|
||||
|
||||
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType aggregateInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||
VectorOrTensorType *type;
|
||||
VectorOrTensorType type;
|
||||
|
||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||
return parser->parseOperand(aggregateInfo) ||
|
||||
|
@ -705,26 +705,26 @@ bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
parser->parseColonType(type) ||
|
||||
parser->resolveOperand(aggregateInfo, type, result->operands) ||
|
||||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
|
||||
parser->addTypeToList(type->getElementType(), result->types);
|
||||
parser->addTypeToList(type.getElementType(), result->types);
|
||||
}
|
||||
|
||||
bool ExtractElementOp::verify() const {
|
||||
if (getNumOperands() == 0)
|
||||
return emitOpError("expected an aggregate to index into");
|
||||
|
||||
auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
|
||||
auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>();
|
||||
if (!aggregateType)
|
||||
return emitOpError("first operand must be a vector or tensor");
|
||||
|
||||
if (getType() != aggregateType->getElementType())
|
||||
if (getType() != aggregateType.getElementType())
|
||||
return emitOpError("result type must match element type of aggregate");
|
||||
|
||||
for (auto *idx : getIndices())
|
||||
if (!idx->getType()->isIndex())
|
||||
if (!idx->getType().isIndex())
|
||||
return emitOpError("index to extract_element must have 'index' type");
|
||||
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
auto aggregateRank = aggregateType->getRank();
|
||||
auto aggregateRank = aggregateType.getRank();
|
||||
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
|
||||
return emitOpError("incorrect number of indices for extract_element");
|
||||
|
||||
|
@ -737,10 +737,10 @@ bool ExtractElementOp::verify() const {
|
|||
|
||||
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
|
||||
ArrayRef<SSAValue *> indices) {
|
||||
auto *memrefType = cast<MemRefType>(memref->getType());
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
result->addOperands(memref);
|
||||
result->addOperands(indices);
|
||||
result->types.push_back(memrefType->getElementType());
|
||||
result->types.push_back(memrefType.getElementType());
|
||||
}
|
||||
|
||||
void LoadOp::print(OpAsmPrinter *p) const {
|
||||
|
@ -748,13 +748,13 @@ void LoadOp::print(OpAsmPrinter *p) const {
|
|||
p->printOperands(getIndices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
*p << " : " << *getMemRefType();
|
||||
*p << " : " << getMemRefType();
|
||||
}
|
||||
|
||||
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType memrefInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||
MemRefType *type;
|
||||
MemRefType type;
|
||||
|
||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||
return parser->parseOperand(memrefInfo) ||
|
||||
|
@ -764,25 +764,25 @@ bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
parser->parseColonType(type) ||
|
||||
parser->resolveOperand(memrefInfo, type, result->operands) ||
|
||||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
|
||||
parser->addTypeToList(type->getElementType(), result->types);
|
||||
parser->addTypeToList(type.getElementType(), result->types);
|
||||
}
|
||||
|
||||
bool LoadOp::verify() const {
|
||||
if (getNumOperands() == 0)
|
||||
return emitOpError("expected a memref to load from");
|
||||
|
||||
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
|
||||
auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
|
||||
if (!memRefType)
|
||||
return emitOpError("first operand must be a memref");
|
||||
|
||||
if (getType() != memRefType->getElementType())
|
||||
if (getType() != memRefType.getElementType())
|
||||
return emitOpError("result type must match element type of memref");
|
||||
|
||||
if (memRefType->getRank() != getNumOperands() - 1)
|
||||
if (memRefType.getRank() != getNumOperands() - 1)
|
||||
return emitOpError("incorrect number of indices for load");
|
||||
|
||||
for (auto *idx : getIndices())
|
||||
if (!idx->getType()->isIndex())
|
||||
if (!idx->getType().isIndex())
|
||||
return emitOpError("index to load must have 'index' type");
|
||||
|
||||
// TODO: Verify we have the right number of indices.
|
||||
|
@ -804,31 +804,31 @@ void LoadOp::getCanonicalizationPatterns(OwningPatternList &results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool MemRefCastOp::verify() const {
|
||||
auto *opType = dyn_cast<MemRefType>(getOperand()->getType());
|
||||
auto *resType = dyn_cast<MemRefType>(getType());
|
||||
auto opType = getOperand()->getType().dyn_cast<MemRefType>();
|
||||
auto resType = getType().dyn_cast<MemRefType>();
|
||||
if (!opType || !resType)
|
||||
return emitOpError("requires input and result types to be memrefs");
|
||||
|
||||
if (opType == resType)
|
||||
return emitOpError("requires the input and result type to be different");
|
||||
|
||||
if (opType->getElementType() != resType->getElementType())
|
||||
if (opType.getElementType() != resType.getElementType())
|
||||
return emitOpError(
|
||||
"requires input and result element types to be the same");
|
||||
|
||||
if (opType->getAffineMaps() != resType->getAffineMaps())
|
||||
if (opType.getAffineMaps() != resType.getAffineMaps())
|
||||
return emitOpError("requires input and result mappings to be the same");
|
||||
|
||||
if (opType->getMemorySpace() != resType->getMemorySpace())
|
||||
if (opType.getMemorySpace() != resType.getMemorySpace())
|
||||
return emitOpError(
|
||||
"requires input and result memory spaces to be the same");
|
||||
|
||||
// They must have the same rank, and any specified dimensions must match.
|
||||
if (opType->getRank() != resType->getRank())
|
||||
if (opType.getRank() != resType.getRank())
|
||||
return emitOpError("requires input and result ranks to match");
|
||||
|
||||
for (unsigned i = 0, e = opType->getRank(); i != e; ++i) {
|
||||
int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i);
|
||||
for (unsigned i = 0, e = opType.getRank(); i != e; ++i) {
|
||||
int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i);
|
||||
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
|
||||
return emitOpError("requires static dimensions to match");
|
||||
}
|
||||
|
@ -923,14 +923,14 @@ void StoreOp::print(OpAsmPrinter *p) const {
|
|||
p->printOperands(getIndices());
|
||||
*p << ']';
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
*p << " : " << *getMemRefType();
|
||||
*p << " : " << getMemRefType();
|
||||
}
|
||||
|
||||
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType storeValueInfo;
|
||||
OpAsmParser::OperandType memrefInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||
MemRefType *memrefType;
|
||||
MemRefType memrefType;
|
||||
|
||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||
return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
|
||||
|
@ -939,7 +939,7 @@ bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
OpAsmParser::Delimiter::Square) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(memrefType) ||
|
||||
parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
|
||||
parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
|
||||
result->operands) ||
|
||||
parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
|
||||
parser->resolveOperands(indexInfo, affineIntTy, result->operands);
|
||||
|
@ -950,19 +950,19 @@ bool StoreOp::verify() const {
|
|||
return emitOpError("expected a value to store and a memref");
|
||||
|
||||
// Second operand is a memref type.
|
||||
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
|
||||
auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
|
||||
if (!memRefType)
|
||||
return emitOpError("second operand must be a memref");
|
||||
|
||||
// First operand must have same type as memref element type.
|
||||
if (getValueToStore()->getType() != memRefType->getElementType())
|
||||
if (getValueToStore()->getType() != memRefType.getElementType())
|
||||
return emitOpError("first operand must have same type memref element type");
|
||||
|
||||
if (getNumOperands() != 2 + memRefType->getRank())
|
||||
if (getNumOperands() != 2 + memRefType.getRank())
|
||||
return emitOpError("store index operand count not equal to memref rank");
|
||||
|
||||
for (auto *idx : getIndices())
|
||||
if (!idx->getType()->isIndex())
|
||||
if (!idx->getType().isIndex())
|
||||
return emitOpError("index to load must have 'index' type");
|
||||
|
||||
// TODO: Verify we have the right number of indices.
|
||||
|
@ -1046,31 +1046,31 @@ void SubIOp::getCanonicalizationPatterns(OwningPatternList &results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool TensorCastOp::verify() const {
|
||||
auto *opType = dyn_cast<TensorType>(getOperand()->getType());
|
||||
auto *resType = dyn_cast<TensorType>(getType());
|
||||
auto opType = getOperand()->getType().dyn_cast<TensorType>();
|
||||
auto resType = getType().dyn_cast<TensorType>();
|
||||
if (!opType || !resType)
|
||||
return emitOpError("requires input and result types to be tensors");
|
||||
|
||||
if (opType == resType)
|
||||
return emitOpError("requires the input and result type to be different");
|
||||
|
||||
if (opType->getElementType() != resType->getElementType())
|
||||
if (opType.getElementType() != resType.getElementType())
|
||||
return emitOpError(
|
||||
"requires input and result element types to be the same");
|
||||
|
||||
// If the source or destination are unranked, then the cast is valid.
|
||||
auto *opRType = dyn_cast<RankedTensorType>(opType);
|
||||
auto *resRType = dyn_cast<RankedTensorType>(resType);
|
||||
auto opRType = opType.dyn_cast<RankedTensorType>();
|
||||
auto resRType = resType.dyn_cast<RankedTensorType>();
|
||||
if (!opRType || !resRType)
|
||||
return false;
|
||||
|
||||
// If they are both ranked, they have to have the same rank, and any specified
|
||||
// dimensions must match.
|
||||
if (opRType->getRank() != resRType->getRank())
|
||||
if (opRType.getRank() != resRType.getRank())
|
||||
return emitOpError("requires input and result ranks to match");
|
||||
|
||||
for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
|
||||
int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
|
||||
for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) {
|
||||
int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i);
|
||||
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
|
||||
return emitOpError("requires static dimensions to match");
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
|
|||
SmallVector<SSAValue *, 8> existingConstants;
|
||||
// Operation statements that were folded and that need to be erased.
|
||||
std::vector<OperationStmt *> opStmtsToErase;
|
||||
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>;
|
||||
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>;
|
||||
|
||||
bool foldOperation(Operation *op,
|
||||
SmallVectorImpl<SSAValue *> &existingConstants,
|
||||
|
@ -106,7 +106,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
|||
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
|
||||
auto &inst = *instIt++;
|
||||
|
||||
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
|
||||
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
|
||||
builder.setInsertionPoint(&inst);
|
||||
return builder.create<ConstantOp>(inst.getLoc(), value, type);
|
||||
};
|
||||
|
@ -134,7 +134,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
|||
|
||||
// Override the walker's operation statement visit for constant folding.
|
||||
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
|
||||
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
|
||||
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
|
||||
MLFuncBuilder builder(stmt);
|
||||
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
|
||||
};
|
||||
|
|
|
@ -77,23 +77,23 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
|
|||
bInner.setInsertionPoint(forStmt, forStmt->begin());
|
||||
|
||||
// Doubles the shape with a leading dimension extent of 2.
|
||||
auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
|
||||
auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
|
||||
// Add the leading dimension in the shape for the double buffer.
|
||||
ArrayRef<int> shape = oldMemRefType->getShape();
|
||||
ArrayRef<int> shape = oldMemRefType.getShape();
|
||||
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
|
||||
shapeSizes.insert(shapeSizes.begin(), 2);
|
||||
|
||||
auto *newMemRefType =
|
||||
bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
|
||||
oldMemRefType->getMemorySpace());
|
||||
auto newMemRefType =
|
||||
bInner.getMemRefType(shapeSizes, oldMemRefType.getElementType(), {},
|
||||
oldMemRefType.getMemorySpace());
|
||||
return newMemRefType;
|
||||
};
|
||||
|
||||
auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
|
||||
auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>());
|
||||
|
||||
// Create and place the alloc at the top level.
|
||||
MLFuncBuilder topBuilder(forStmt->getFunction());
|
||||
auto *newMemRef = cast<MLValue>(
|
||||
auto newMemRef = cast<MLValue>(
|
||||
topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
|
||||
->getResult());
|
||||
|
||||
|
|
|
@ -78,7 +78,7 @@ private:
|
|||
/// As part of canonicalization, we move constants to the top of the entry
|
||||
/// block of the current function and de-duplicate them. This keeps track of
|
||||
/// constants we have done this for.
|
||||
DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
|
||||
DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
|
||||
};
|
||||
}; // end anonymous namespace
|
||||
|
||||
|
|
|
@ -52,9 +52,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
|
|||
MLValue *newMemRef,
|
||||
ArrayRef<MLValue *> extraIndices,
|
||||
AffineMap indexRemap) {
|
||||
unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
|
||||
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
|
||||
(void)newMemRefRank; // unused in opt mode
|
||||
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
|
||||
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
|
||||
(void)newMemRefRank;
|
||||
if (indexRemap) {
|
||||
assert(indexRemap.getNumInputs() == oldMemRefRank);
|
||||
|
@ -64,8 +64,8 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
|
|||
}
|
||||
|
||||
// Assert same elemental type.
|
||||
assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
|
||||
cast<MemRefType>(newMemRef->getType())->getElementType());
|
||||
assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
|
||||
newMemRef->getType().cast<MemRefType>().getElementType());
|
||||
|
||||
// Check if memref was used in a non-deferencing context.
|
||||
for (const StmtOperand &use : oldMemRef->getUses()) {
|
||||
|
@ -139,7 +139,7 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
|
|||
opStmt->operand_end());
|
||||
|
||||
// Result types don't change. Both memref's are of the same elemental type.
|
||||
SmallVector<Type *, 8> resultTypes;
|
||||
SmallVector<Type, 8> resultTypes;
|
||||
resultTypes.reserve(opStmt->getNumResults());
|
||||
for (const auto *result : opStmt->getResults())
|
||||
resultTypes.push_back(result->getType());
|
||||
|
|
|
@ -202,15 +202,15 @@ static bool analyzeProfitability(MLFunctionMatches matches,
|
|||
/// sizes specified by vectorSize. The MemRef lives in the same memory space as
|
||||
/// tmpl. The MemRef should be promoted to a closer memory address space in a
|
||||
/// later pass.
|
||||
static MemRefType *getVectorizedMemRefType(MemRefType *tmpl,
|
||||
ArrayRef<int> vectorSizes) {
|
||||
auto *elementType = tmpl->getElementType();
|
||||
assert(!dyn_cast<VectorType>(elementType) &&
|
||||
static MemRefType getVectorizedMemRefType(MemRefType tmpl,
|
||||
ArrayRef<int> vectorSizes) {
|
||||
auto elementType = tmpl.getElementType();
|
||||
assert(!elementType.dyn_cast<VectorType>() &&
|
||||
"Can't vectorize an already vector type");
|
||||
assert(tmpl->getAffineMaps().empty() &&
|
||||
assert(tmpl.getAffineMaps().empty() &&
|
||||
"Unsupported non-implicit identity map");
|
||||
return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {},
|
||||
tmpl->getMemorySpace());
|
||||
tmpl.getMemorySpace());
|
||||
}
|
||||
|
||||
/// Creates an unaligned load with the following semantics:
|
||||
|
@ -258,7 +258,7 @@ static void createUnalignedLoad(MLFuncBuilder *b, Location *loc,
|
|||
operands.insert(operands.end(), dstMemRef);
|
||||
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
|
||||
using functional::map;
|
||||
std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
|
||||
std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
|
||||
return v->getType();
|
||||
};
|
||||
auto types = map(getType, operands);
|
||||
|
@ -310,7 +310,7 @@ static void createUnalignedStore(MLFuncBuilder *b, Location *loc,
|
|||
operands.insert(operands.end(), dstMemRef);
|
||||
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
|
||||
using functional::map;
|
||||
std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
|
||||
std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
|
||||
return v->getType();
|
||||
};
|
||||
auto types = map(getType, operands);
|
||||
|
@ -348,8 +348,9 @@ static std::function<ToType *(T *)> unwrapPtr() {
|
|||
template <typename LoadOrStoreOpPointer>
|
||||
static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
|
||||
ArrayRef<int> vectorSize) {
|
||||
auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
|
||||
auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
|
||||
auto memRefType =
|
||||
memoryOp->getMemRef()->getType().template cast<MemRefType>();
|
||||
auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
|
||||
|
||||
// Materialize a MemRef with 1 vector.
|
||||
auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());
|
||||
|
|
Loading…
Reference in New Issue