forked from OSchip/llvm-project
Implement value type abstraction for types.
This is done by changing Type to be a POD interface around an underlying pointer storage and adding in-class support for isa/dyn_cast/cast. PiperOrigin-RevId: 219372163
This commit is contained in:
parent
75376b8e33
commit
4c465a181d
|
@ -51,7 +51,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
|
||||||
/// whether indices[dim] is independent of the value `input`.
|
/// whether indices[dim] is independent of the value `input`.
|
||||||
// For now we assume no layout map or identity layout map in the MemRef.
|
// For now we assume no layout map or identity layout map in the MemRef.
|
||||||
// TODO(ntv): support more than identity layout map.
|
// TODO(ntv): support more than identity layout map.
|
||||||
bool isAccessInvariant(const MLValue &input, MemRefType *memRefType,
|
bool isAccessInvariant(const MLValue &input, MemRefType memRefType,
|
||||||
llvm::ArrayRef<MLValue *> indices, unsigned dim);
|
llvm::ArrayRef<MLValue *> indices, unsigned dim);
|
||||||
|
|
||||||
/// Checks whether all the LoadOp and StoreOp matched have access indexing
|
/// Checks whether all the LoadOp and StoreOp matched have access indexing
|
||||||
|
|
|
@ -250,9 +250,9 @@ public:
|
||||||
TypeAttr() = default;
|
TypeAttr() = default;
|
||||||
/* implicit */ TypeAttr(Attribute::ImplType *ptr);
|
/* implicit */ TypeAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
static TypeAttr get(Type *type, MLIRContext *context);
|
static TypeAttr get(Type type, MLIRContext *context);
|
||||||
|
|
||||||
Type *getValue() const;
|
Type getValue() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool kindof(Kind kind) { return kind == Kind::Type; }
|
static bool kindof(Kind kind) { return kind == Kind::Type; }
|
||||||
|
@ -277,7 +277,7 @@ public:
|
||||||
|
|
||||||
Function *getValue() const;
|
Function *getValue() const;
|
||||||
|
|
||||||
FunctionType *getType() const;
|
FunctionType getType() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool kindof(Kind kind) { return kind == Kind::Function; }
|
static bool kindof(Kind kind) { return kind == Kind::Function; }
|
||||||
|
@ -294,7 +294,7 @@ public:
|
||||||
ElementsAttr() = default;
|
ElementsAttr() = default;
|
||||||
/* implicit */ ElementsAttr(Attribute::ImplType *ptr);
|
/* implicit */ ElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
VectorOrTensorType *getType() const;
|
VectorOrTensorType getType() const;
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
static bool kindof(Kind kind) {
|
static bool kindof(Kind kind) {
|
||||||
|
@ -313,7 +313,7 @@ public:
|
||||||
SplatElementsAttr() = default;
|
SplatElementsAttr() = default;
|
||||||
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
|
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
static SplatElementsAttr get(VectorOrTensorType *type, Attribute elt);
|
static SplatElementsAttr get(VectorOrTensorType type, Attribute elt);
|
||||||
Attribute getValue() const;
|
Attribute getValue() const;
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
|
@ -335,12 +335,12 @@ public:
|
||||||
/// width specified by the element type (note all float type are 64 bits).
|
/// width specified by the element type (note all float type are 64 bits).
|
||||||
/// When the value is retrieved, the bits are read from the storage and extend
|
/// When the value is retrieved, the bits are read from the storage and extend
|
||||||
/// to 64 bits if necessary.
|
/// to 64 bits if necessary.
|
||||||
static DenseElementsAttr get(VectorOrTensorType *type, ArrayRef<char> data);
|
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data);
|
||||||
|
|
||||||
// TODO: Read the data from the attribute list and compress them
|
// TODO: Read the data from the attribute list and compress them
|
||||||
// to a character array. Then call the above method to construct the
|
// to a character array. Then call the above method to construct the
|
||||||
// attribute.
|
// attribute.
|
||||||
static DenseElementsAttr get(VectorOrTensorType *type,
|
static DenseElementsAttr get(VectorOrTensorType type,
|
||||||
ArrayRef<Attribute> values);
|
ArrayRef<Attribute> values);
|
||||||
|
|
||||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||||
|
@ -410,7 +410,7 @@ public:
|
||||||
OpaqueElementsAttr() = default;
|
OpaqueElementsAttr() = default;
|
||||||
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
|
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
static OpaqueElementsAttr get(VectorOrTensorType *type, StringRef bytes);
|
static OpaqueElementsAttr get(VectorOrTensorType type, StringRef bytes);
|
||||||
|
|
||||||
StringRef getValue() const;
|
StringRef getValue() const;
|
||||||
|
|
||||||
|
@ -440,7 +440,7 @@ public:
|
||||||
SparseElementsAttr() = default;
|
SparseElementsAttr() = default;
|
||||||
/* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
|
/* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
static SparseElementsAttr get(VectorOrTensorType *type,
|
static SparseElementsAttr get(VectorOrTensorType type,
|
||||||
DenseIntElementsAttr indices,
|
DenseIntElementsAttr indices,
|
||||||
DenseElementsAttr values);
|
DenseElementsAttr values);
|
||||||
|
|
||||||
|
|
|
@ -64,10 +64,10 @@ public:
|
||||||
bool args_empty() const { return arguments.empty(); }
|
bool args_empty() const { return arguments.empty(); }
|
||||||
|
|
||||||
/// Add one value to the operand list.
|
/// Add one value to the operand list.
|
||||||
BBArgument *addArgument(Type *type);
|
BBArgument *addArgument(Type type);
|
||||||
|
|
||||||
/// Add one argument to the argument list for each type specified in the list.
|
/// Add one argument to the argument list for each type specified in the list.
|
||||||
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type *> types);
|
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
|
||||||
|
|
||||||
unsigned getNumArguments() const { return arguments.size(); }
|
unsigned getNumArguments() const { return arguments.size(); }
|
||||||
BBArgument *getArgument(unsigned i) { return arguments[i]; }
|
BBArgument *getArgument(unsigned i) { return arguments[i]; }
|
||||||
|
|
|
@ -68,29 +68,28 @@ public:
|
||||||
unsigned column);
|
unsigned column);
|
||||||
|
|
||||||
// Types.
|
// Types.
|
||||||
FloatType *getBF16Type();
|
FloatType getBF16Type();
|
||||||
FloatType *getF16Type();
|
FloatType getF16Type();
|
||||||
FloatType *getF32Type();
|
FloatType getF32Type();
|
||||||
FloatType *getF64Type();
|
FloatType getF64Type();
|
||||||
|
|
||||||
OtherType *getIndexType();
|
OtherType getIndexType();
|
||||||
OtherType *getTFControlType();
|
OtherType getTFControlType();
|
||||||
OtherType *getTFStringType();
|
OtherType getTFStringType();
|
||||||
OtherType *getTFResourceType();
|
OtherType getTFResourceType();
|
||||||
OtherType *getTFVariantType();
|
OtherType getTFVariantType();
|
||||||
OtherType *getTFComplex64Type();
|
OtherType getTFComplex64Type();
|
||||||
OtherType *getTFComplex128Type();
|
OtherType getTFComplex128Type();
|
||||||
OtherType *getTFF32REFType();
|
OtherType getTFF32REFType();
|
||||||
|
|
||||||
IntegerType *getIntegerType(unsigned width);
|
IntegerType getIntegerType(unsigned width);
|
||||||
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
|
FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
|
||||||
ArrayRef<Type *> results);
|
MemRefType getMemRefType(ArrayRef<int> shape, Type elementType,
|
||||||
MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
|
ArrayRef<AffineMap> affineMapComposition = {},
|
||||||
ArrayRef<AffineMap> affineMapComposition = {},
|
unsigned memorySpace = 0);
|
||||||
unsigned memorySpace = 0);
|
VectorType getVectorType(ArrayRef<int> shape, Type elementType);
|
||||||
VectorType *getVectorType(ArrayRef<int> shape, Type *elementType);
|
RankedTensorType getTensorType(ArrayRef<int> shape, Type elementType);
|
||||||
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
|
UnrankedTensorType getTensorType(Type elementType);
|
||||||
UnrankedTensorType *getTensorType(Type *elementType);
|
|
||||||
|
|
||||||
// Attributes.
|
// Attributes.
|
||||||
|
|
||||||
|
@ -102,15 +101,15 @@ public:
|
||||||
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
|
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
|
||||||
AffineMapAttr getAffineMapAttr(AffineMap map);
|
AffineMapAttr getAffineMapAttr(AffineMap map);
|
||||||
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
|
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
|
||||||
TypeAttr getTypeAttr(Type *type);
|
TypeAttr getTypeAttr(Type type);
|
||||||
FunctionAttr getFunctionAttr(const Function *value);
|
FunctionAttr getFunctionAttr(const Function *value);
|
||||||
ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt);
|
ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt);
|
||||||
ElementsAttr getDenseElementsAttr(VectorOrTensorType *type,
|
ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
|
||||||
ArrayRef<char> data);
|
ArrayRef<char> data);
|
||||||
ElementsAttr getSparseElementsAttr(VectorOrTensorType *type,
|
ElementsAttr getSparseElementsAttr(VectorOrTensorType type,
|
||||||
DenseIntElementsAttr indices,
|
DenseIntElementsAttr indices,
|
||||||
DenseElementsAttr values);
|
DenseElementsAttr values);
|
||||||
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes);
|
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes);
|
||||||
|
|
||||||
// Affine expressions and affine maps.
|
// Affine expressions and affine maps.
|
||||||
AffineExpr getAffineDimExpr(unsigned position);
|
AffineExpr getAffineDimExpr(unsigned position);
|
||||||
|
@ -366,7 +365,7 @@ public:
|
||||||
/// Creates an operation given the fields.
|
/// Creates an operation given the fields.
|
||||||
OperationStmt *createOperation(Location *location, OperationName name,
|
OperationStmt *createOperation(Location *location, OperationName name,
|
||||||
ArrayRef<MLValue *> operands,
|
ArrayRef<MLValue *> operands,
|
||||||
ArrayRef<Type *> types,
|
ArrayRef<Type> types,
|
||||||
ArrayRef<NamedAttribute> attrs);
|
ArrayRef<NamedAttribute> attrs);
|
||||||
|
|
||||||
/// Create operation of specific op type at the current insertion point.
|
/// Create operation of specific op type at the current insertion point.
|
||||||
|
|
|
@ -96,7 +96,7 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
|
||||||
public:
|
public:
|
||||||
/// Builds a constant op with the specified attribute value and result type.
|
/// Builds a constant op with the specified attribute value and result type.
|
||||||
static void build(Builder *builder, OperationState *result, Attribute value,
|
static void build(Builder *builder, OperationState *result, Attribute value,
|
||||||
Type *type);
|
Type type);
|
||||||
|
|
||||||
Attribute getValue() const { return getAttr("value"); }
|
Attribute getValue() const { return getAttr("value"); }
|
||||||
|
|
||||||
|
@ -123,7 +123,7 @@ class ConstantFloatOp : public ConstantOp {
|
||||||
public:
|
public:
|
||||||
/// Builds a constant float op producing a float of the specified type.
|
/// Builds a constant float op producing a float of the specified type.
|
||||||
static void build(Builder *builder, OperationState *result,
|
static void build(Builder *builder, OperationState *result,
|
||||||
const APFloat &value, FloatType *type);
|
const APFloat &value, FloatType type);
|
||||||
|
|
||||||
APFloat getValue() const {
|
APFloat getValue() const {
|
||||||
return getAttrOfType<FloatAttr>("value").getValue();
|
return getAttrOfType<FloatAttr>("value").getValue();
|
||||||
|
@ -150,7 +150,7 @@ public:
|
||||||
/// Build a constant int op producing an integer with the specified type,
|
/// Build a constant int op producing an integer with the specified type,
|
||||||
/// which must be an integer type.
|
/// which must be an integer type.
|
||||||
static void build(Builder *builder, OperationState *result, int64_t value,
|
static void build(Builder *builder, OperationState *result, int64_t value,
|
||||||
Type *type);
|
Type type);
|
||||||
|
|
||||||
int64_t getValue() const {
|
int64_t getValue() const {
|
||||||
return getAttrOfType<IntegerAttr>("value").getValue();
|
return getAttrOfType<IntegerAttr>("value").getValue();
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace mlir {
|
||||||
// blocks, each of which includes instructions.
|
// blocks, each of which includes instructions.
|
||||||
class CFGFunction : public Function {
|
class CFGFunction : public Function {
|
||||||
public:
|
public:
|
||||||
CFGFunction(Location *location, StringRef name, FunctionType *type,
|
CFGFunction(Location *location, StringRef name, FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs = {});
|
ArrayRef<NamedAttribute> attrs = {});
|
||||||
|
|
||||||
~CFGFunction();
|
~CFGFunction();
|
||||||
|
|
|
@ -66,7 +66,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
|
CFGValue(CFGValueKind kind, Type type) : SSAValueImpl(kind, type) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Basic block arguments are CFG Values.
|
/// Basic block arguments are CFG Values.
|
||||||
|
@ -87,7 +87,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class BasicBlock; // For access to private constructor.
|
friend class BasicBlock; // For access to private constructor.
|
||||||
BBArgument(Type *type, BasicBlock *owner)
|
BBArgument(Type type, BasicBlock *owner)
|
||||||
: CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
|
: CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
|
||||||
|
|
||||||
/// The owner of this operand.
|
/// The owner of this operand.
|
||||||
|
@ -99,7 +99,7 @@ private:
|
||||||
/// Instruction results are CFG Values.
|
/// Instruction results are CFG Values.
|
||||||
class InstResult : public CFGValue {
|
class InstResult : public CFGValue {
|
||||||
public:
|
public:
|
||||||
InstResult(Type *type, OperationInst *owner)
|
InstResult(Type type, OperationInst *owner)
|
||||||
: CFGValue(CFGValueKind::InstResult, type), owner(owner) {}
|
: CFGValue(CFGValueKind::InstResult, type), owner(owner) {}
|
||||||
|
|
||||||
static bool classof(const SSAValue *value) {
|
static bool classof(const SSAValue *value) {
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
|
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Identifier.h"
|
#include "mlir/IR/Identifier.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/ilist.h"
|
#include "llvm/ADT/ilist.h"
|
||||||
|
|
||||||
|
@ -55,7 +56,7 @@ public:
|
||||||
Identifier getName() const { return nameAndKind.getPointer(); }
|
Identifier getName() const { return nameAndKind.getPointer(); }
|
||||||
|
|
||||||
/// Return the type of this function.
|
/// Return the type of this function.
|
||||||
FunctionType *getType() const { return type; }
|
FunctionType getType() const { return type; }
|
||||||
|
|
||||||
/// Returns all of the attributes on this function.
|
/// Returns all of the attributes on this function.
|
||||||
ArrayRef<NamedAttribute> getAttrs() const;
|
ArrayRef<NamedAttribute> getAttrs() const;
|
||||||
|
@ -93,7 +94,7 @@ public:
|
||||||
void emitNote(const Twine &message) const;
|
void emitNote(const Twine &message) const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Function(Kind kind, Location *location, StringRef name, FunctionType *type,
|
Function(Kind kind, Location *location, StringRef name, FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs = {});
|
ArrayRef<NamedAttribute> attrs = {});
|
||||||
~Function();
|
~Function();
|
||||||
|
|
||||||
|
@ -108,7 +109,7 @@ private:
|
||||||
Location *location;
|
Location *location;
|
||||||
|
|
||||||
/// The type of the function.
|
/// The type of the function.
|
||||||
FunctionType *const type;
|
FunctionType type;
|
||||||
|
|
||||||
/// This holds general named attributes for the function.
|
/// This holds general named attributes for the function.
|
||||||
AttributeListStorage *attrs;
|
AttributeListStorage *attrs;
|
||||||
|
@ -121,7 +122,7 @@ private:
|
||||||
/// defined in some other module.
|
/// defined in some other module.
|
||||||
class ExtFunction : public Function {
|
class ExtFunction : public Function {
|
||||||
public:
|
public:
|
||||||
ExtFunction(Location *location, StringRef name, FunctionType *type,
|
ExtFunction(Location *location, StringRef name, FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs = {});
|
ArrayRef<NamedAttribute> attrs = {});
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
|
|
|
@ -202,7 +202,7 @@ public:
|
||||||
/// Create a new OperationInst with the specified fields.
|
/// Create a new OperationInst with the specified fields.
|
||||||
static OperationInst *create(Location *location, OperationName name,
|
static OperationInst *create(Location *location, OperationName name,
|
||||||
ArrayRef<CFGValue *> operands,
|
ArrayRef<CFGValue *> operands,
|
||||||
ArrayRef<Type *> resultTypes,
|
ArrayRef<Type> resultTypes,
|
||||||
ArrayRef<NamedAttribute> attributes,
|
ArrayRef<NamedAttribute> attributes,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ class MLFunction final
|
||||||
public:
|
public:
|
||||||
/// Creates a new MLFunction with the specific type.
|
/// Creates a new MLFunction with the specific type.
|
||||||
static MLFunction *create(Location *location, StringRef name,
|
static MLFunction *create(Location *location, StringRef name,
|
||||||
FunctionType *type,
|
FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs = {});
|
ArrayRef<NamedAttribute> attrs = {});
|
||||||
|
|
||||||
/// Destroys this statement and its subclass data.
|
/// Destroys this statement and its subclass data.
|
||||||
|
@ -52,7 +52,7 @@ public:
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Returns number of arguments.
|
/// Returns number of arguments.
|
||||||
unsigned getNumArguments() const { return getType()->getInputs().size(); }
|
unsigned getNumArguments() const { return getType().getInputs().size(); }
|
||||||
|
|
||||||
/// Gets argument.
|
/// Gets argument.
|
||||||
MLFuncArgument *getArgument(unsigned idx) {
|
MLFuncArgument *getArgument(unsigned idx) {
|
||||||
|
@ -103,13 +103,13 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MLFunction(Location *location, StringRef name, FunctionType *type,
|
MLFunction(Location *location, StringRef name, FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs = {});
|
ArrayRef<NamedAttribute> attrs = {});
|
||||||
|
|
||||||
// This stuff is used by the TrailingObjects template.
|
// This stuff is used by the TrailingObjects template.
|
||||||
friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;
|
friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;
|
||||||
size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const {
|
size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const {
|
||||||
return getType()->getInputs().size();
|
return getType().getInputs().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Internal functions to get argument list used by getArgument() methods.
|
// Internal functions to get argument list used by getArgument() methods.
|
||||||
|
|
|
@ -73,7 +73,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
MLValue(MLValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
|
MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// This is the value defined by an argument of an ML function.
|
/// This is the value defined by an argument of an ML function.
|
||||||
|
@ -93,7 +93,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class MLFunction; // For access to private constructor.
|
friend class MLFunction; // For access to private constructor.
|
||||||
MLFuncArgument(Type *type, MLFunction *owner)
|
MLFuncArgument(Type type, MLFunction *owner)
|
||||||
: MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {}
|
: MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {}
|
||||||
|
|
||||||
/// The owner of this operand.
|
/// The owner of this operand.
|
||||||
|
@ -105,7 +105,7 @@ private:
|
||||||
/// This is a value defined by a result of an operation instruction.
|
/// This is a value defined by a result of an operation instruction.
|
||||||
class StmtResult : public MLValue {
|
class StmtResult : public MLValue {
|
||||||
public:
|
public:
|
||||||
StmtResult(Type *type, OperationStmt *owner)
|
StmtResult(Type type, OperationStmt *owner)
|
||||||
: MLValue(MLValueKind::StmtResult, type), owner(owner) {}
|
: MLValue(MLValueKind::StmtResult, type), owner(owner) {}
|
||||||
|
|
||||||
static bool classof(const SSAValue *value) {
|
static bool classof(const SSAValue *value) {
|
||||||
|
|
|
@ -71,13 +71,13 @@ struct constant_int_op_binder {
|
||||||
|
|
||||||
bool match(Operation *op) {
|
bool match(Operation *op) {
|
||||||
if (auto constOp = op->dyn_cast<ConstantOp>()) {
|
if (auto constOp = op->dyn_cast<ConstantOp>()) {
|
||||||
auto *type = constOp->getResult()->getType();
|
auto type = constOp->getResult()->getType();
|
||||||
auto attr = constOp->getAttr("value");
|
auto attr = constOp->getAttr("value");
|
||||||
|
|
||||||
if (isa<IntegerType>(type)) {
|
if (type.isa<IntegerType>()) {
|
||||||
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
|
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
|
||||||
}
|
}
|
||||||
if (isa<VectorOrTensorType>(type)) {
|
if (type.isa<VectorOrTensorType>()) {
|
||||||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||||
return attr_value_binder<IntegerAttr>(bind_value)
|
return attr_value_binder<IntegerAttr>(bind_value)
|
||||||
.match(splatAttr.getValue());
|
.match(splatAttr.getValue());
|
||||||
|
|
|
@ -493,7 +493,7 @@ public:
|
||||||
return this->getOperation()->getResult(0);
|
return this->getOperation()->getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
Type *getType() const { return getResult()->getType(); }
|
Type getType() const { return getResult()->getType(); }
|
||||||
|
|
||||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||||
/// the IR that uses 'this' to use the other value instead. When this returns
|
/// the IR that uses 'this' to use the other value instead. When this returns
|
||||||
|
@ -539,7 +539,7 @@ public:
|
||||||
return this->getOperation()->getResult(i);
|
return this->getOperation()->getResult(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
Type *getType(unsigned i) const { return getResult(i)->getType(); }
|
Type getType(unsigned i) const { return getResult(i)->getType(); }
|
||||||
|
|
||||||
static bool verifyTrait(const Operation *op) {
|
static bool verifyTrait(const Operation *op) {
|
||||||
return impl::verifyNResults(op, N);
|
return impl::verifyNResults(op, N);
|
||||||
|
@ -565,7 +565,7 @@ public:
|
||||||
return this->getOperation()->getResult(i);
|
return this->getOperation()->getResult(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
Type *getType(unsigned i) const { return getResult(i)->getType(); }
|
Type getType(unsigned i) const { return getResult(i)->getType(); }
|
||||||
|
|
||||||
static bool verifyTrait(const Operation *op) {
|
static bool verifyTrait(const Operation *op) {
|
||||||
return impl::verifyAtLeastNResults(op, N);
|
return impl::verifyAtLeastNResults(op, N);
|
||||||
|
@ -803,7 +803,7 @@ protected:
|
||||||
// which avoids them being template instantiated/duplicated.
|
// which avoids them being template instantiated/duplicated.
|
||||||
namespace impl {
|
namespace impl {
|
||||||
void buildCastOp(Builder *builder, OperationState *result, SSAValue *source,
|
void buildCastOp(Builder *builder, OperationState *result, SSAValue *source,
|
||||||
Type *destType);
|
Type destType);
|
||||||
bool parseCastOp(OpAsmParser *parser, OperationState *result);
|
bool parseCastOp(OpAsmParser *parser, OperationState *result);
|
||||||
void printCastOp(const Operation *op, OpAsmPrinter *p);
|
void printCastOp(const Operation *op, OpAsmPrinter *p);
|
||||||
} // namespace impl
|
} // namespace impl
|
||||||
|
@ -819,7 +819,7 @@ class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
|
||||||
OpTrait::HasNoSideEffect, Traits...> {
|
OpTrait::HasNoSideEffect, Traits...> {
|
||||||
public:
|
public:
|
||||||
static void build(Builder *builder, OperationState *result, SSAValue *source,
|
static void build(Builder *builder, OperationState *result, SSAValue *source,
|
||||||
Type *destType) {
|
Type destType) {
|
||||||
impl::buildCastOp(builder, result, source, destType);
|
impl::buildCastOp(builder, result, source, destType);
|
||||||
}
|
}
|
||||||
static bool parse(OpAsmParser *parser, OperationState *result) {
|
static bool parse(OpAsmParser *parser, OperationState *result) {
|
||||||
|
|
|
@ -67,7 +67,7 @@ public:
|
||||||
printOperand(*it);
|
printOperand(*it);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual void printType(const Type *type) = 0;
|
virtual void printType(Type type) = 0;
|
||||||
virtual void printFunctionReference(const Function *func) = 0;
|
virtual void printFunctionReference(const Function *func) = 0;
|
||||||
virtual void printAttribute(Attribute attr) = 0;
|
virtual void printAttribute(Attribute attr) = 0;
|
||||||
virtual void printAffineMap(AffineMap map) = 0;
|
virtual void printAffineMap(AffineMap map) = 0;
|
||||||
|
@ -95,8 +95,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const SSAValue &value) {
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Type &type) {
|
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
|
||||||
p.printType(&type);
|
p.printType(type);
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,20 +163,20 @@ public:
|
||||||
virtual bool parseComma() = 0;
|
virtual bool parseComma() = 0;
|
||||||
|
|
||||||
/// Parse a colon followed by a type.
|
/// Parse a colon followed by a type.
|
||||||
virtual bool parseColonType(Type *&result) = 0;
|
virtual bool parseColonType(Type &result) = 0;
|
||||||
|
|
||||||
/// Parse a type of a specific kind, e.g. a FunctionType.
|
/// Parse a type of a specific kind, e.g. a FunctionType.
|
||||||
template <typename TypeType> bool parseColonType(TypeType *&result) {
|
template <typename TypeType> bool parseColonType(TypeType &result) {
|
||||||
llvm::SMLoc loc;
|
llvm::SMLoc loc;
|
||||||
getCurrentLocation(&loc);
|
getCurrentLocation(&loc);
|
||||||
|
|
||||||
// Parse any kind of type.
|
// Parse any kind of type.
|
||||||
Type *type;
|
Type type;
|
||||||
if (parseColonType(type))
|
if (parseColonType(type))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
// Check for the right kind of attribute.
|
// Check for the right kind of attribute.
|
||||||
result = dyn_cast<TypeType>(type);
|
result = type.dyn_cast<TypeType>();
|
||||||
if (!result) {
|
if (!result) {
|
||||||
emitError(loc, "invalid kind of type specified");
|
emitError(loc, "invalid kind of type specified");
|
||||||
return true;
|
return true;
|
||||||
|
@ -186,15 +186,15 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse a colon followed by a type list, which must have at least one type.
|
/// Parse a colon followed by a type list, which must have at least one type.
|
||||||
virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result) = 0;
|
virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
|
||||||
|
|
||||||
/// Parse a keyword followed by a type.
|
/// Parse a keyword followed by a type.
|
||||||
virtual bool parseKeywordType(const char *keyword, Type *&result) = 0;
|
virtual bool parseKeywordType(const char *keyword, Type &result) = 0;
|
||||||
|
|
||||||
/// Add the specified type to the end of the specified type list and return
|
/// Add the specified type to the end of the specified type list and return
|
||||||
/// false. This is a helper designed to allow parse methods to be simple and
|
/// false. This is a helper designed to allow parse methods to be simple and
|
||||||
/// chain through || operators.
|
/// chain through || operators.
|
||||||
bool addTypeToList(Type *type, SmallVectorImpl<Type *> &result) {
|
bool addTypeToList(Type type, SmallVectorImpl<Type> &result) {
|
||||||
result.push_back(type);
|
result.push_back(type);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -202,7 +202,7 @@ public:
|
||||||
/// Add the specified types to the end of the specified type list and return
|
/// Add the specified types to the end of the specified type list and return
|
||||||
/// false. This is a helper designed to allow parse methods to be simple and
|
/// false. This is a helper designed to allow parse methods to be simple and
|
||||||
/// chain through || operators.
|
/// chain through || operators.
|
||||||
bool addTypesToList(ArrayRef<Type *> types, SmallVectorImpl<Type *> &result) {
|
bool addTypesToList(ArrayRef<Type> types, SmallVectorImpl<Type> &result) {
|
||||||
result.append(types.begin(), types.end());
|
result.append(types.begin(), types.end());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -288,13 +288,13 @@ public:
|
||||||
|
|
||||||
/// Resolve an operand to an SSA value, emitting an error and returning true
|
/// Resolve an operand to an SSA value, emitting an error and returning true
|
||||||
/// on failure.
|
/// on failure.
|
||||||
virtual bool resolveOperand(const OperandType &operand, Type *type,
|
virtual bool resolveOperand(const OperandType &operand, Type type,
|
||||||
SmallVectorImpl<SSAValue *> &result) = 0;
|
SmallVectorImpl<SSAValue *> &result) = 0;
|
||||||
|
|
||||||
/// Resolve a list of operands to SSA values, emitting an error and returning
|
/// Resolve a list of operands to SSA values, emitting an error and returning
|
||||||
/// true on failure, or appending the results to the list on success.
|
/// true on failure, or appending the results to the list on success.
|
||||||
/// This method should be used when all operands have the same type.
|
/// This method should be used when all operands have the same type.
|
||||||
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type *type,
|
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type,
|
||||||
SmallVectorImpl<SSAValue *> &result) {
|
SmallVectorImpl<SSAValue *> &result) {
|
||||||
for (auto elt : operands)
|
for (auto elt : operands)
|
||||||
if (resolveOperand(elt, type, result))
|
if (resolveOperand(elt, type, result))
|
||||||
|
@ -306,7 +306,7 @@ public:
|
||||||
/// emitting an error and returning true on failure, or appending the results
|
/// emitting an error and returning true on failure, or appending the results
|
||||||
/// to the list on success.
|
/// to the list on success.
|
||||||
virtual bool resolveOperands(ArrayRef<OperandType> operands,
|
virtual bool resolveOperands(ArrayRef<OperandType> operands,
|
||||||
ArrayRef<Type *> types, llvm::SMLoc loc,
|
ArrayRef<Type> types, llvm::SMLoc loc,
|
||||||
SmallVectorImpl<SSAValue *> &result) {
|
SmallVectorImpl<SSAValue *> &result) {
|
||||||
if (operands.size() != types.size())
|
if (operands.size() != types.size())
|
||||||
return emitError(loc, Twine(operands.size()) +
|
return emitError(loc, Twine(operands.size()) +
|
||||||
|
@ -321,7 +321,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve a parse function name and a type into a function reference.
|
/// Resolve a parse function name and a type into a function reference.
|
||||||
virtual bool resolveFunctionName(StringRef name, FunctionType *type,
|
virtual bool resolveFunctionName(StringRef name, FunctionType type,
|
||||||
llvm::SMLoc loc, Function *&result) = 0;
|
llvm::SMLoc loc, Function *&result) = 0;
|
||||||
|
|
||||||
/// Emit a diagnostic at the specified location and return true.
|
/// Emit a diagnostic at the specified location and return true.
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
|
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Identifier.h"
|
#include "mlir/IR/Identifier.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
#include "llvm/ADT/PointerUnion.h"
|
#include "llvm/ADT/PointerUnion.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
@ -191,7 +192,7 @@ struct OperationState {
|
||||||
OperationName name;
|
OperationName name;
|
||||||
SmallVector<SSAValue *, 4> operands;
|
SmallVector<SSAValue *, 4> operands;
|
||||||
/// Types of the results of this operation.
|
/// Types of the results of this operation.
|
||||||
SmallVector<Type *, 4> types;
|
SmallVector<Type, 4> types;
|
||||||
SmallVector<NamedAttribute, 4> attributes;
|
SmallVector<NamedAttribute, 4> attributes;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -202,7 +203,7 @@ public:
|
||||||
: context(context), location(location), name(name) {}
|
: context(context), location(location), name(name) {}
|
||||||
|
|
||||||
OperationState(MLIRContext *context, Location *location, StringRef name,
|
OperationState(MLIRContext *context, Location *location, StringRef name,
|
||||||
ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
|
ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
|
||||||
ArrayRef<NamedAttribute> attributes = {})
|
ArrayRef<NamedAttribute> attributes = {})
|
||||||
: context(context), location(location), name(name, context),
|
: context(context), location(location), name(name, context),
|
||||||
operands(operands.begin(), operands.end()),
|
operands(operands.begin(), operands.end()),
|
||||||
|
@ -213,7 +214,7 @@ public:
|
||||||
operands.append(newOperands.begin(), newOperands.end());
|
operands.append(newOperands.begin(), newOperands.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
void addTypes(ArrayRef<Type *> newTypes) {
|
void addTypes(ArrayRef<Type> newTypes) {
|
||||||
types.append(newTypes.begin(), newTypes.end());
|
types.append(newTypes.begin(), newTypes.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,6 @@
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
#include "mlir/IR/UseDefLists.h"
|
#include "mlir/IR/UseDefLists.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/PointerIntPair.h"
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class Function;
|
class Function;
|
||||||
|
@ -51,7 +50,7 @@ public:
|
||||||
|
|
||||||
SSAValueKind getKind() const { return typeAndKind.getInt(); }
|
SSAValueKind getKind() const { return typeAndKind.getInt(); }
|
||||||
|
|
||||||
Type *getType() const { return typeAndKind.getPointer(); }
|
Type getType() const { return typeAndKind.getPointer(); }
|
||||||
|
|
||||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||||
/// the IR that uses 'this' to use the other value instead. When this returns
|
/// the IR that uses 'this' to use the other value instead. When this returns
|
||||||
|
@ -93,9 +92,10 @@ public:
|
||||||
void dump() const;
|
void dump() const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
SSAValue(SSAValueKind kind, Type *type) : typeAndKind(type, kind) {}
|
SSAValue(SSAValueKind kind, Type type) : typeAndKind(type, kind) {}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const llvm::PointerIntPair<Type *, 3, SSAValueKind> typeAndKind;
|
const llvm::PointerIntPair<Type, 3, SSAValueKind> typeAndKind;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) {
|
inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) {
|
||||||
|
@ -127,7 +127,7 @@ public:
|
||||||
inline use_range getUses() const;
|
inline use_range getUses() const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
SSAValueImpl(KindTy kind, Type *type) : SSAValue((SSAValueKind)kind, type) {}
|
SSAValueImpl(KindTy kind, Type type) : SSAValue((SSAValueKind)kind, type) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Utility functions for iterating through SSAValue uses.
|
// Utility functions for iterating through SSAValue uses.
|
||||||
|
|
|
@ -44,7 +44,7 @@ public:
|
||||||
/// Create a new OperationStmt with the specific fields.
|
/// Create a new OperationStmt with the specific fields.
|
||||||
static OperationStmt *create(Location *location, OperationName name,
|
static OperationStmt *create(Location *location, OperationName name,
|
||||||
ArrayRef<MLValue *> operands,
|
ArrayRef<MLValue *> operands,
|
||||||
ArrayRef<Type *> resultTypes,
|
ArrayRef<Type> resultTypes,
|
||||||
ArrayRef<NamedAttribute> attributes,
|
ArrayRef<NamedAttribute> attributes,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
|
|
||||||
|
@ -329,7 +329,7 @@ public:
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Return the context this operation is associated with.
|
/// Return the context this operation is associated with.
|
||||||
MLIRContext *getContext() const { return getType()->getContext(); }
|
MLIRContext *getContext() const { return getType().getContext(); }
|
||||||
|
|
||||||
using Statement::dump;
|
using Statement::dump;
|
||||||
using Statement::print;
|
using Statement::print;
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/DenseMapInfo.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class AffineMap;
|
class AffineMap;
|
||||||
|
@ -28,6 +29,22 @@ class IntegerType;
|
||||||
class FloatType;
|
class FloatType;
|
||||||
class OtherType;
|
class OtherType;
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
class TypeStorage;
|
||||||
|
class IntegerTypeStorage;
|
||||||
|
class FloatTypeStorage;
|
||||||
|
struct OtherTypeStorage;
|
||||||
|
struct FunctionTypeStorage;
|
||||||
|
struct VectorOrTensorTypeStorage;
|
||||||
|
struct VectorTypeStorage;
|
||||||
|
struct TensorTypeStorage;
|
||||||
|
struct RankedTensorTypeStorage;
|
||||||
|
struct UnrankedTensorTypeStorage;
|
||||||
|
struct MemRefTypeStorage;
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
|
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
|
||||||
/// MLIRContext. As such, they are passed around by raw non-const pointer.
|
/// MLIRContext. As such, they are passed around by raw non-const pointer.
|
||||||
///
|
///
|
||||||
|
@ -68,11 +85,34 @@ public:
|
||||||
MemRef,
|
MemRef,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using ImplType = detail::TypeStorage;
|
||||||
|
|
||||||
|
Type() : type(nullptr) {}
|
||||||
|
/* implicit */ Type(const ImplType *type)
|
||||||
|
: type(const_cast<ImplType *>(type)) {}
|
||||||
|
|
||||||
|
Type(const Type &other) : type(other.type) {}
|
||||||
|
Type &operator=(Type other) {
|
||||||
|
type = other.type;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(Type other) const { return type == other.type; }
|
||||||
|
bool operator!=(Type other) const { return !(*this == other); }
|
||||||
|
explicit operator bool() const { return type; }
|
||||||
|
|
||||||
|
bool operator!() const { return type == nullptr; }
|
||||||
|
|
||||||
|
template <typename U> bool isa() const;
|
||||||
|
template <typename U> U dyn_cast() const;
|
||||||
|
template <typename U> U dyn_cast_or_null() const;
|
||||||
|
template <typename U> U cast() const;
|
||||||
|
|
||||||
/// Return the classification for this type.
|
/// Return the classification for this type.
|
||||||
Kind getKind() const { return kind; }
|
Kind getKind() const;
|
||||||
|
|
||||||
/// Return the LLVMContext in which this type was uniqued.
|
/// Return the LLVMContext in which this type was uniqued.
|
||||||
MLIRContext *getContext() const { return context; }
|
MLIRContext *getContext() const;
|
||||||
|
|
||||||
// Convenience predicates. This is only for 'other' and floating point types,
|
// Convenience predicates. This is only for 'other' and floating point types,
|
||||||
// derived types should use isa/dyn_cast.
|
// derived types should use isa/dyn_cast.
|
||||||
|
@ -97,56 +137,42 @@ public:
|
||||||
unsigned getBitWidth() const;
|
unsigned getBitWidth() const;
|
||||||
|
|
||||||
// Convenience factories.
|
// Convenience factories.
|
||||||
static IntegerType *getInteger(unsigned width, MLIRContext *ctx);
|
static IntegerType getInteger(unsigned width, MLIRContext *ctx);
|
||||||
static FloatType *getBF16(MLIRContext *ctx);
|
static FloatType getBF16(MLIRContext *ctx);
|
||||||
static FloatType *getF16(MLIRContext *ctx);
|
static FloatType getF16(MLIRContext *ctx);
|
||||||
static FloatType *getF32(MLIRContext *ctx);
|
static FloatType getF32(MLIRContext *ctx);
|
||||||
static FloatType *getF64(MLIRContext *ctx);
|
static FloatType getF64(MLIRContext *ctx);
|
||||||
static OtherType *getIndex(MLIRContext *ctx);
|
static OtherType getIndex(MLIRContext *ctx);
|
||||||
static OtherType *getTFControl(MLIRContext *ctx);
|
static OtherType getTFControl(MLIRContext *ctx);
|
||||||
static OtherType *getTFString(MLIRContext *ctx);
|
static OtherType getTFString(MLIRContext *ctx);
|
||||||
static OtherType *getTFResource(MLIRContext *ctx);
|
static OtherType getTFResource(MLIRContext *ctx);
|
||||||
static OtherType *getTFVariant(MLIRContext *ctx);
|
static OtherType getTFVariant(MLIRContext *ctx);
|
||||||
static OtherType *getTFComplex64(MLIRContext *ctx);
|
static OtherType getTFComplex64(MLIRContext *ctx);
|
||||||
static OtherType *getTFComplex128(MLIRContext *ctx);
|
static OtherType getTFComplex128(MLIRContext *ctx);
|
||||||
static OtherType *getTFF32REF(MLIRContext *ctx);
|
static OtherType getTFF32REF(MLIRContext *ctx);
|
||||||
|
|
||||||
/// Print the current type.
|
/// Print the current type.
|
||||||
void print(raw_ostream &os) const;
|
void print(raw_ostream &os) const;
|
||||||
void dump() const;
|
void dump() const;
|
||||||
|
|
||||||
|
friend ::llvm::hash_code hash_value(Type arg);
|
||||||
|
|
||||||
|
unsigned getSubclassData() const;
|
||||||
|
void setSubclassData(unsigned val);
|
||||||
|
|
||||||
|
/// Methods for supporting PointerLikeTypeTraits.
|
||||||
|
const void *getAsOpaquePointer() const {
|
||||||
|
return static_cast<const void *>(type);
|
||||||
|
}
|
||||||
|
static Type getFromOpaquePointer(const void *pointer) {
|
||||||
|
return Type((ImplType *)(pointer));
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit Type(Kind kind, MLIRContext *context)
|
ImplType *type;
|
||||||
: context(context), kind(kind), subclassData(0) {}
|
|
||||||
explicit Type(Kind kind, MLIRContext *context, unsigned subClassData)
|
|
||||||
: Type(kind, context) {
|
|
||||||
setSubclassData(subClassData);
|
|
||||||
}
|
|
||||||
|
|
||||||
~Type() {}
|
|
||||||
|
|
||||||
unsigned getSubclassData() const { return subclassData; }
|
|
||||||
|
|
||||||
void setSubclassData(unsigned val) {
|
|
||||||
subclassData = val;
|
|
||||||
// Ensure we don't have any accidental truncation.
|
|
||||||
assert(getSubclassData() == val && "Subclass data too large for field");
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
Type(const Type&) = delete;
|
|
||||||
void operator=(const Type&) = delete;
|
|
||||||
/// This refers to the MLIRContext in which this type was uniqued.
|
|
||||||
MLIRContext *const context;
|
|
||||||
|
|
||||||
/// Classification of the subclass, used for type checking.
|
|
||||||
Kind kind : 8;
|
|
||||||
|
|
||||||
// Space for subclasses to store data.
|
|
||||||
unsigned subclassData : 24;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
|
inline raw_ostream &operator<<(raw_ostream &os, Type type) {
|
||||||
type.print(os);
|
type.print(os);
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
@ -154,148 +180,138 @@ inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
|
||||||
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
|
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
|
||||||
class IntegerType : public Type {
|
class IntegerType : public Type {
|
||||||
public:
|
public:
|
||||||
static IntegerType *get(unsigned width, MLIRContext *context);
|
using ImplType = detail::IntegerTypeStorage;
|
||||||
|
IntegerType() = default;
|
||||||
|
/* implicit */ IntegerType(Type::ImplType *ptr);
|
||||||
|
|
||||||
|
static IntegerType get(unsigned width, MLIRContext *context);
|
||||||
|
|
||||||
/// Return the bitwidth of this integer type.
|
/// Return the bitwidth of this integer type.
|
||||||
unsigned getWidth() const {
|
unsigned getWidth() const;
|
||||||
return width;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Type *type) {
|
static bool kindof(Kind kind) { return kind == Kind::Integer; }
|
||||||
return type->getKind() == Kind::Integer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Integer representation maximal bitwidth.
|
/// Integer representation maximal bitwidth.
|
||||||
static constexpr unsigned kMaxWidth = 4096;
|
static constexpr unsigned kMaxWidth = 4096;
|
||||||
private:
|
|
||||||
unsigned width;
|
|
||||||
IntegerType(unsigned width, MLIRContext *context);
|
|
||||||
~IntegerType() = delete;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
inline IntegerType *Type::getInteger(unsigned width, MLIRContext *ctx) {
|
inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
|
||||||
return IntegerType::get(width, ctx);
|
return IntegerType::get(width, ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return true if this is an integer type with the specified width.
|
/// Return true if this is an integer type with the specified width.
|
||||||
inline bool Type::isInteger(unsigned width) const {
|
inline bool Type::isInteger(unsigned width) const {
|
||||||
if (auto *intTy = dyn_cast<IntegerType>(this))
|
if (auto intTy = dyn_cast<IntegerType>())
|
||||||
return intTy->getWidth() == width;
|
return intTy.getWidth() == width;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
class FloatType : public Type {
|
class FloatType : public Type {
|
||||||
public:
|
public:
|
||||||
|
using ImplType = detail::FloatTypeStorage;
|
||||||
|
FloatType() = default;
|
||||||
|
/* implicit */ FloatType(Type::ImplType *ptr);
|
||||||
|
|
||||||
|
static FloatType get(Kind kind, MLIRContext *context);
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Type *type) {
|
static bool kindof(Kind kind) {
|
||||||
return type->getKind() >= Kind::FIRST_FLOATING_POINT_TYPE &&
|
return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
|
||||||
type->getKind() <= Kind::LAST_FLOATING_POINT_TYPE;
|
kind <= Kind::LAST_FLOATING_POINT_TYPE;
|
||||||
}
|
}
|
||||||
|
|
||||||
static FloatType *get(Kind kind, MLIRContext *context);
|
|
||||||
|
|
||||||
private:
|
|
||||||
FloatType(Kind kind, MLIRContext *context);
|
|
||||||
~FloatType() = delete;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
inline FloatType *Type::getBF16(MLIRContext *ctx) {
|
inline FloatType Type::getBF16(MLIRContext *ctx) {
|
||||||
return FloatType::get(Kind::BF16, ctx);
|
return FloatType::get(Kind::BF16, ctx);
|
||||||
}
|
}
|
||||||
inline FloatType *Type::getF16(MLIRContext *ctx) {
|
inline FloatType Type::getF16(MLIRContext *ctx) {
|
||||||
return FloatType::get(Kind::F16, ctx);
|
return FloatType::get(Kind::F16, ctx);
|
||||||
}
|
}
|
||||||
inline FloatType *Type::getF32(MLIRContext *ctx) {
|
inline FloatType Type::getF32(MLIRContext *ctx) {
|
||||||
return FloatType::get(Kind::F32, ctx);
|
return FloatType::get(Kind::F32, ctx);
|
||||||
}
|
}
|
||||||
inline FloatType *Type::getF64(MLIRContext *ctx) {
|
inline FloatType Type::getF64(MLIRContext *ctx) {
|
||||||
return FloatType::get(Kind::F64, ctx);
|
return FloatType::get(Kind::F64, ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This is a type for the random collection of special base types.
|
/// This is a type for the random collection of special base types.
|
||||||
class OtherType : public Type {
|
class OtherType : public Type {
|
||||||
public:
|
public:
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
using ImplType = detail::OtherTypeStorage;
|
||||||
static bool classof(const Type *type) {
|
OtherType() = default;
|
||||||
return type->getKind() >= Kind::FIRST_OTHER_TYPE &&
|
/* implicit */ OtherType(Type::ImplType *ptr);
|
||||||
type->getKind() <= Kind::LAST_OTHER_TYPE;
|
|
||||||
}
|
|
||||||
static OtherType *get(Kind kind, MLIRContext *context);
|
|
||||||
|
|
||||||
private:
|
static OtherType get(Kind kind, MLIRContext *context);
|
||||||
OtherType(Kind kind, MLIRContext *context);
|
|
||||||
~OtherType() = delete;
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
|
static bool kindof(Kind kind) {
|
||||||
|
return kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
inline OtherType *Type::getIndex(MLIRContext *ctx) {
|
inline OtherType Type::getIndex(MLIRContext *ctx) {
|
||||||
return OtherType::get(Kind::Index, ctx);
|
return OtherType::get(Kind::Index, ctx);
|
||||||
}
|
}
|
||||||
inline OtherType *Type::getTFControl(MLIRContext *ctx) {
|
inline OtherType Type::getTFControl(MLIRContext *ctx) {
|
||||||
return OtherType::get(Kind::TFControl, ctx);
|
return OtherType::get(Kind::TFControl, ctx);
|
||||||
}
|
}
|
||||||
inline OtherType *Type::getTFResource(MLIRContext *ctx) {
|
inline OtherType Type::getTFResource(MLIRContext *ctx) {
|
||||||
return OtherType::get(Kind::TFResource, ctx);
|
return OtherType::get(Kind::TFResource, ctx);
|
||||||
}
|
}
|
||||||
inline OtherType *Type::getTFString(MLIRContext *ctx) {
|
inline OtherType Type::getTFString(MLIRContext *ctx) {
|
||||||
return OtherType::get(Kind::TFString, ctx);
|
return OtherType::get(Kind::TFString, ctx);
|
||||||
}
|
}
|
||||||
inline OtherType *Type::getTFVariant(MLIRContext *ctx) {
|
inline OtherType Type::getTFVariant(MLIRContext *ctx) {
|
||||||
return OtherType::get(Kind::TFVariant, ctx);
|
return OtherType::get(Kind::TFVariant, ctx);
|
||||||
}
|
}
|
||||||
inline OtherType *Type::getTFComplex64(MLIRContext *ctx) {
|
inline OtherType Type::getTFComplex64(MLIRContext *ctx) {
|
||||||
return OtherType::get(Kind::TFComplex64, ctx);
|
return OtherType::get(Kind::TFComplex64, ctx);
|
||||||
}
|
}
|
||||||
inline OtherType *Type::getTFComplex128(MLIRContext *ctx) {
|
inline OtherType Type::getTFComplex128(MLIRContext *ctx) {
|
||||||
return OtherType::get(Kind::TFComplex128, ctx);
|
return OtherType::get(Kind::TFComplex128, ctx);
|
||||||
}
|
}
|
||||||
inline OtherType *Type::getTFF32REF(MLIRContext *ctx) {
|
inline OtherType Type::getTFF32REF(MLIRContext *ctx) {
|
||||||
return OtherType::get(Kind::TFF32REF, ctx);
|
return OtherType::get(Kind::TFF32REF, ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Function types map from a list of inputs to a list of results.
|
/// Function types map from a list of inputs to a list of results.
|
||||||
class FunctionType : public Type {
|
class FunctionType : public Type {
|
||||||
public:
|
public:
|
||||||
static FunctionType *get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
|
using ImplType = detail::FunctionTypeStorage;
|
||||||
MLIRContext *context);
|
FunctionType() = default;
|
||||||
|
/* implicit */ FunctionType(Type::ImplType *ptr);
|
||||||
|
|
||||||
|
static FunctionType get(ArrayRef<Type> inputs, ArrayRef<Type> results,
|
||||||
|
MLIRContext *context);
|
||||||
|
|
||||||
// Input types.
|
// Input types.
|
||||||
unsigned getNumInputs() const { return getSubclassData(); }
|
unsigned getNumInputs() const { return getSubclassData(); }
|
||||||
|
|
||||||
Type *getInput(unsigned i) const { return getInputs()[i]; }
|
Type getInput(unsigned i) const { return getInputs()[i]; }
|
||||||
|
|
||||||
ArrayRef<Type*> getInputs() const {
|
ArrayRef<Type> getInputs() const;
|
||||||
return ArrayRef<Type *>(inputsAndResults, getNumInputs());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Result types.
|
// Result types.
|
||||||
unsigned getNumResults() const { return numResults; }
|
unsigned getNumResults() const;
|
||||||
|
|
||||||
Type *getResult(unsigned i) const { return getResults()[i]; }
|
Type getResult(unsigned i) const { return getResults()[i]; }
|
||||||
|
|
||||||
ArrayRef<Type*> getResults() const {
|
ArrayRef<Type> getResults() const;
|
||||||
return ArrayRef<Type *>(inputsAndResults + getSubclassData(), numResults);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Type *type) {
|
static bool kindof(Kind kind) { return kind == Kind::Function; }
|
||||||
return type->getKind() == Kind::Function;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
unsigned numResults;
|
|
||||||
Type *const *inputsAndResults;
|
|
||||||
|
|
||||||
FunctionType(Type *const *inputsAndResults, unsigned numInputs,
|
|
||||||
unsigned numResults, MLIRContext *context);
|
|
||||||
~FunctionType() = delete;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
|
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
|
||||||
/// types, because many operations work on values of these aggregate types.
|
/// types, because many operations work on values of these aggregate types.
|
||||||
class VectorOrTensorType : public Type {
|
class VectorOrTensorType : public Type {
|
||||||
public:
|
public:
|
||||||
Type *getElementType() const { return elementType; }
|
using ImplType = detail::VectorOrTensorTypeStorage;
|
||||||
|
VectorOrTensorType() = default;
|
||||||
|
/* implicit */ VectorOrTensorType(Type::ImplType *ptr);
|
||||||
|
|
||||||
|
Type getElementType() const;
|
||||||
|
|
||||||
/// If this is ranked tensor or vector type, return the number of elements. If
|
/// If this is ranked tensor or vector type, return the number of elements. If
|
||||||
/// it is an unranked tensor or vector, abort.
|
/// it is an unranked tensor or vector, abort.
|
||||||
|
@ -319,56 +335,40 @@ public:
|
||||||
int getDimSize(unsigned i) const;
|
int getDimSize(unsigned i) const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Type *type) {
|
static bool kindof(Kind kind) {
|
||||||
return type->getKind() == Kind::Vector ||
|
return kind == Kind::Vector || kind == Kind::RankedTensor ||
|
||||||
type->getKind() == Kind::RankedTensor ||
|
kind == Kind::UnrankedTensor;
|
||||||
type->getKind() == Kind::UnrankedTensor;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
|
||||||
Type *elementType;
|
|
||||||
|
|
||||||
VectorOrTensorType(Kind kind, MLIRContext *context, Type *elementType,
|
|
||||||
unsigned subClassData = 0);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
|
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
|
||||||
/// known constant shape with one or more dimension.
|
/// known constant shape with one or more dimension.
|
||||||
class VectorType : public VectorOrTensorType {
|
class VectorType : public VectorOrTensorType {
|
||||||
public:
|
public:
|
||||||
static VectorType *get(ArrayRef<int> shape, Type *elementType);
|
using ImplType = detail::VectorTypeStorage;
|
||||||
|
VectorType() = default;
|
||||||
|
/* implicit */ VectorType(Type::ImplType *ptr);
|
||||||
|
|
||||||
ArrayRef<int> getShape() const {
|
static VectorType get(ArrayRef<int> shape, Type elementType);
|
||||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
|
||||||
}
|
ArrayRef<int> getShape() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Type *type) {
|
static bool kindof(Kind kind) { return kind == Kind::Vector; }
|
||||||
return type->getKind() == Kind::Vector;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const int *shapeElements;
|
|
||||||
Type *elementType;
|
|
||||||
|
|
||||||
VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context);
|
|
||||||
~VectorType() = delete;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Tensor types represent multi-dimensional arrays, and have two variants:
|
/// Tensor types represent multi-dimensional arrays, and have two variants:
|
||||||
/// RankedTensorType and UnrankedTensorType.
|
/// RankedTensorType and UnrankedTensorType.
|
||||||
class TensorType : public VectorOrTensorType {
|
class TensorType : public VectorOrTensorType {
|
||||||
public:
|
public:
|
||||||
|
using ImplType = detail::TensorTypeStorage;
|
||||||
|
TensorType() = default;
|
||||||
|
/* implicit */ TensorType(Type::ImplType *ptr);
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Type *type) {
|
static bool kindof(Kind kind) {
|
||||||
return type->getKind() == Kind::RankedTensor ||
|
return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor;
|
||||||
type->getKind() == Kind::UnrankedTensor;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
|
||||||
TensorType(Kind kind, Type *elementType, MLIRContext *context);
|
|
||||||
~TensorType() {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Ranked tensor types represent multi-dimensional arrays that have a shape
|
/// Ranked tensor types represent multi-dimensional arrays that have a shape
|
||||||
|
@ -376,40 +376,30 @@ protected:
|
||||||
/// integer or unknown (represented -1).
|
/// integer or unknown (represented -1).
|
||||||
class RankedTensorType : public TensorType {
|
class RankedTensorType : public TensorType {
|
||||||
public:
|
public:
|
||||||
static RankedTensorType *get(ArrayRef<int> shape,
|
using ImplType = detail::RankedTensorTypeStorage;
|
||||||
Type *elementType);
|
RankedTensorType() = default;
|
||||||
|
/* implicit */ RankedTensorType(Type::ImplType *ptr);
|
||||||
|
|
||||||
ArrayRef<int> getShape() const {
|
static RankedTensorType get(ArrayRef<int> shape, Type elementType);
|
||||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool classof(const Type *type) {
|
ArrayRef<int> getShape() const;
|
||||||
return type->getKind() == Kind::RankedTensor;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
static bool kindof(Kind kind) { return kind == Kind::RankedTensor; }
|
||||||
const int *shapeElements;
|
|
||||||
|
|
||||||
RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
|
||||||
MLIRContext *context);
|
|
||||||
~RankedTensorType() = delete;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Unranked tensor types represent multi-dimensional arrays that have an
|
/// Unranked tensor types represent multi-dimensional arrays that have an
|
||||||
/// unknown shape.
|
/// unknown shape.
|
||||||
class UnrankedTensorType : public TensorType {
|
class UnrankedTensorType : public TensorType {
|
||||||
public:
|
public:
|
||||||
static UnrankedTensorType *get(Type *elementType);
|
using ImplType = detail::UnrankedTensorTypeStorage;
|
||||||
|
UnrankedTensorType() = default;
|
||||||
|
/* implicit */ UnrankedTensorType(Type::ImplType *ptr);
|
||||||
|
|
||||||
|
static UnrankedTensorType get(Type elementType);
|
||||||
|
|
||||||
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
|
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
|
||||||
|
|
||||||
static bool classof(const Type *type) {
|
static bool kindof(Kind kind) { return kind == Kind::UnrankedTensor; }
|
||||||
return type->getKind() == Kind::UnrankedTensor;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
UnrankedTensorType(Type *elementType, MLIRContext *context);
|
|
||||||
~UnrankedTensorType() = delete;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// MemRef types represent a region of memory that have a shape with a fixed
|
/// MemRef types represent a region of memory that have a shape with a fixed
|
||||||
|
@ -418,62 +408,96 @@ private:
|
||||||
/// affine map composition, represented as an array AffineMap pointers.
|
/// affine map composition, represented as an array AffineMap pointers.
|
||||||
class MemRefType : public Type {
|
class MemRefType : public Type {
|
||||||
public:
|
public:
|
||||||
|
using ImplType = detail::MemRefTypeStorage;
|
||||||
|
MemRefType() = default;
|
||||||
|
/* implicit */ MemRefType(Type::ImplType *ptr);
|
||||||
|
|
||||||
/// Get or create a new MemRefType based on shape, element type, affine
|
/// Get or create a new MemRefType based on shape, element type, affine
|
||||||
/// map composition, and memory space.
|
/// map composition, and memory space.
|
||||||
static MemRefType *get(ArrayRef<int> shape, Type *elementType,
|
static MemRefType get(ArrayRef<int> shape, Type elementType,
|
||||||
ArrayRef<AffineMap> affineMapComposition,
|
ArrayRef<AffineMap> affineMapComposition,
|
||||||
unsigned memorySpace);
|
unsigned memorySpace);
|
||||||
|
|
||||||
unsigned getRank() const { return getShape().size(); }
|
unsigned getRank() const { return getShape().size(); }
|
||||||
|
|
||||||
/// Returns an array of memref shape dimension sizes.
|
/// Returns an array of memref shape dimension sizes.
|
||||||
ArrayRef<int> getShape() const {
|
ArrayRef<int> getShape() const;
|
||||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return the size of the specified dimension, or -1 if unspecified.
|
/// Return the size of the specified dimension, or -1 if unspecified.
|
||||||
int getDimSize(unsigned i) const { return getShape()[i]; }
|
int getDimSize(unsigned i) const { return getShape()[i]; }
|
||||||
|
|
||||||
/// Returns the elemental type for this memref shape.
|
/// Returns the elemental type for this memref shape.
|
||||||
Type *getElementType() const { return elementType; }
|
Type getElementType() const;
|
||||||
|
|
||||||
/// Returns an array of affine map pointers representing the memref affine
|
/// Returns an array of affine map pointers representing the memref affine
|
||||||
/// map composition.
|
/// map composition.
|
||||||
ArrayRef<AffineMap> getAffineMaps() const;
|
ArrayRef<AffineMap> getAffineMaps() const;
|
||||||
|
|
||||||
/// Returns the memory space in which data referred to by this memref resides.
|
/// Returns the memory space in which data referred to by this memref resides.
|
||||||
unsigned getMemorySpace() const { return memorySpace; }
|
unsigned getMemorySpace() const;
|
||||||
|
|
||||||
/// Returns the number of dimensions with dynamic size.
|
/// Returns the number of dimensions with dynamic size.
|
||||||
unsigned getNumDynamicDims() const;
|
unsigned getNumDynamicDims() const;
|
||||||
|
|
||||||
static bool classof(const Type *type) {
|
static bool kindof(Kind kind) { return kind == Kind::MemRef; }
|
||||||
return type->getKind() == Kind::MemRef;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/// The type of each scalar element of the memref.
|
|
||||||
Type *elementType;
|
|
||||||
/// An array of integers which stores the shape dimension sizes.
|
|
||||||
const int *shapeElements;
|
|
||||||
/// The number of affine maps in the 'affineMapList' array.
|
|
||||||
const unsigned numAffineMaps;
|
|
||||||
/// List of affine maps in the memref's layout/index map composition.
|
|
||||||
AffineMap const *affineMapList;
|
|
||||||
/// Memory space in which data referenced by memref resides.
|
|
||||||
const unsigned memorySpace;
|
|
||||||
|
|
||||||
MemRefType(ArrayRef<int> shape, Type *elementType,
|
|
||||||
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
|
|
||||||
MLIRContext *context);
|
|
||||||
~MemRefType() = delete;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Return true if the specified element type is ok in a tensor.
|
// Make Type hashable.
|
||||||
static bool isValidTensorElementType(Type *type) {
|
inline ::llvm::hash_code hash_value(Type arg) {
|
||||||
return isa<FloatType>(type) || isa<VectorType>(type) ||
|
return ::llvm::hash_value(arg.type);
|
||||||
isa<IntegerType>(type) || isa<OtherType>(type);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename U> bool Type::isa() const {
|
||||||
|
assert(type && "isa<> used on a null type.");
|
||||||
|
return U::kindof(getKind());
|
||||||
|
}
|
||||||
|
template <typename U> U Type::dyn_cast() const {
|
||||||
|
return isa<U>() ? U(type) : U(nullptr);
|
||||||
|
}
|
||||||
|
template <typename U> U Type::dyn_cast_or_null() const {
|
||||||
|
return (type && isa<U>()) ? U(type) : U(nullptr);
|
||||||
|
}
|
||||||
|
template <typename U> U Type::cast() const {
|
||||||
|
assert(isa<U>());
|
||||||
|
return U(type);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return true if the specified element type is ok in a tensor.
|
||||||
|
static bool isValidTensorElementType(Type type) {
|
||||||
|
return type.isa<FloatType>() || type.isa<VectorType>() ||
|
||||||
|
type.isa<IntegerType>() || type.isa<OtherType>();
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
|
||||||
|
// Type hash just like pointers.
|
||||||
|
template <> struct DenseMapInfo<mlir::Type> {
|
||||||
|
static mlir::Type getEmptyKey() {
|
||||||
|
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||||
|
return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
|
||||||
|
}
|
||||||
|
static mlir::Type getTombstoneKey() {
|
||||||
|
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||||
|
return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
|
||||||
|
}
|
||||||
|
static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
|
||||||
|
static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
|
||||||
|
};
|
||||||
|
|
||||||
|
/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
|
||||||
|
template <> struct PointerLikeTypeTraits<mlir::Type> {
|
||||||
|
public:
|
||||||
|
static inline void *getAsVoidPointer(mlir::Type I) {
|
||||||
|
return const_cast<void *>(I.getAsOpaquePointer());
|
||||||
|
}
|
||||||
|
static inline mlir::Type getFromVoidPointer(void *P) {
|
||||||
|
return mlir::Type::getFromOpaquePointer(P);
|
||||||
|
}
|
||||||
|
enum { NumLowBitsAvailable = 3 };
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace llvm
|
||||||
|
|
||||||
#endif // MLIR_IR_TYPES_H
|
#endif // MLIR_IR_TYPES_H
|
||||||
|
|
|
@ -104,15 +104,15 @@ class AllocOp
|
||||||
: public Op<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
|
: public Op<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
|
||||||
public:
|
public:
|
||||||
/// The result of an alloc is always a MemRefType.
|
/// The result of an alloc is always a MemRefType.
|
||||||
MemRefType *getType() const {
|
MemRefType getType() const {
|
||||||
return cast<MemRefType>(getResult()->getType());
|
return getResult()->getType().cast<MemRefType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static StringRef getOperationName() { return "alloc"; }
|
static StringRef getOperationName() { return "alloc"; }
|
||||||
|
|
||||||
// Hooks to customize behavior of this op.
|
// Hooks to customize behavior of this op.
|
||||||
static void build(Builder *builder, OperationState *result,
|
static void build(Builder *builder, OperationState *result,
|
||||||
MemRefType *memrefType, ArrayRef<SSAValue *> operands = {});
|
MemRefType memrefType, ArrayRef<SSAValue *> operands = {});
|
||||||
bool verify() const;
|
bool verify() const;
|
||||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||||
void print(OpAsmPrinter *p) const;
|
void print(OpAsmPrinter *p) const;
|
||||||
|
@ -276,7 +276,7 @@ public:
|
||||||
const SSAValue *getSrcMemRef() const { return getOperand(0); }
|
const SSAValue *getSrcMemRef() const { return getOperand(0); }
|
||||||
// Returns the rank (number of indices) of the source MemRefType.
|
// Returns the rank (number of indices) of the source MemRefType.
|
||||||
unsigned getSrcMemRefRank() const {
|
unsigned getSrcMemRefRank() const {
|
||||||
return cast<MemRefType>(getSrcMemRef()->getType())->getRank();
|
return getSrcMemRef()->getType().cast<MemRefType>().getRank();
|
||||||
}
|
}
|
||||||
// Returns the source memerf indices for this DMA operation.
|
// Returns the source memerf indices for this DMA operation.
|
||||||
llvm::iterator_range<Operation::const_operand_iterator>
|
llvm::iterator_range<Operation::const_operand_iterator>
|
||||||
|
@ -291,13 +291,13 @@ public:
|
||||||
}
|
}
|
||||||
// Returns the rank (number of indices) of the destination MemRefType.
|
// Returns the rank (number of indices) of the destination MemRefType.
|
||||||
unsigned getDstMemRefRank() const {
|
unsigned getDstMemRefRank() const {
|
||||||
return cast<MemRefType>(getDstMemRef()->getType())->getRank();
|
return getDstMemRef()->getType().cast<MemRefType>().getRank();
|
||||||
}
|
}
|
||||||
unsigned getSrcMemorySpace() const {
|
unsigned getSrcMemorySpace() const {
|
||||||
return cast<MemRefType>(getSrcMemRef()->getType())->getMemorySpace();
|
return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
|
||||||
}
|
}
|
||||||
unsigned getDstMemorySpace() const {
|
unsigned getDstMemorySpace() const {
|
||||||
return cast<MemRefType>(getDstMemRef()->getType())->getMemorySpace();
|
return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the destination memref indices for this DMA operation.
|
// Returns the destination memref indices for this DMA operation.
|
||||||
|
@ -387,7 +387,7 @@ public:
|
||||||
|
|
||||||
// Returns the rank (number of indices) of the tag memref.
|
// Returns the rank (number of indices) of the tag memref.
|
||||||
unsigned getTagMemRefRank() const {
|
unsigned getTagMemRefRank() const {
|
||||||
return cast<MemRefType>(getTagMemRef()->getType())->getRank();
|
return getTagMemRef()->getType().cast<MemRefType>().getRank();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the number of elements transferred in the associated DMA operation.
|
// Returns the number of elements transferred in the associated DMA operation.
|
||||||
|
@ -460,8 +460,8 @@ public:
|
||||||
SSAValue *getMemRef() { return getOperand(0); }
|
SSAValue *getMemRef() { return getOperand(0); }
|
||||||
const SSAValue *getMemRef() const { return getOperand(0); }
|
const SSAValue *getMemRef() const { return getOperand(0); }
|
||||||
void setMemRef(SSAValue *value) { setOperand(0, value); }
|
void setMemRef(SSAValue *value) { setOperand(0, value); }
|
||||||
MemRefType *getMemRefType() const {
|
MemRefType getMemRefType() const {
|
||||||
return cast<MemRefType>(getMemRef()->getType());
|
return getMemRef()->getType().cast<MemRefType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::iterator_range<Operation::operand_iterator> getIndices() {
|
llvm::iterator_range<Operation::operand_iterator> getIndices() {
|
||||||
|
@ -508,8 +508,8 @@ public:
|
||||||
static StringRef getOperationName() { return "memref_cast"; }
|
static StringRef getOperationName() { return "memref_cast"; }
|
||||||
|
|
||||||
/// The result of a memref_cast is always a memref.
|
/// The result of a memref_cast is always a memref.
|
||||||
MemRefType *getType() const {
|
MemRefType getType() const {
|
||||||
return cast<MemRefType>(getResult()->getType());
|
return getResult()->getType().cast<MemRefType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool verify() const;
|
bool verify() const;
|
||||||
|
@ -583,8 +583,8 @@ public:
|
||||||
SSAValue *getMemRef() { return getOperand(1); }
|
SSAValue *getMemRef() { return getOperand(1); }
|
||||||
const SSAValue *getMemRef() const { return getOperand(1); }
|
const SSAValue *getMemRef() const { return getOperand(1); }
|
||||||
void setMemRef(SSAValue *value) { setOperand(1, value); }
|
void setMemRef(SSAValue *value) { setOperand(1, value); }
|
||||||
MemRefType *getMemRefType() const {
|
MemRefType getMemRefType() const {
|
||||||
return cast<MemRefType>(getMemRef()->getType());
|
return getMemRef()->getType().cast<MemRefType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::iterator_range<Operation::operand_iterator> getIndices() {
|
llvm::iterator_range<Operation::operand_iterator> getIndices() {
|
||||||
|
@ -671,8 +671,8 @@ public:
|
||||||
static StringRef getOperationName() { return "tensor_cast"; }
|
static StringRef getOperationName() { return "tensor_cast"; }
|
||||||
|
|
||||||
/// The result of a tensor_cast is always a tensor.
|
/// The result of a tensor_cast is always a tensor.
|
||||||
TensorType *getType() const {
|
TensorType getType() const {
|
||||||
return cast<TensorType>(getResult()->getType());
|
return getResult()->getType().cast<TensorType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool verify() const;
|
bool verify() const;
|
||||||
|
|
|
@ -118,15 +118,15 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
|
||||||
return tripCountExpr.getLargestKnownDivisor();
|
return tripCountExpr.getLargestKnownDivisor();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mlir::isAccessInvariant(const MLValue &input, MemRefType *memRefType,
|
bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
|
||||||
ArrayRef<MLValue *> indices, unsigned dim) {
|
ArrayRef<MLValue *> indices, unsigned dim) {
|
||||||
assert(indices.size() == memRefType->getRank());
|
assert(indices.size() == memRefType.getRank());
|
||||||
assert(dim < indices.size());
|
assert(dim < indices.size());
|
||||||
auto layoutMap = memRefType->getAffineMaps();
|
auto layoutMap = memRefType.getAffineMaps();
|
||||||
assert(memRefType->getAffineMaps().size() <= 1);
|
assert(memRefType.getAffineMaps().size() <= 1);
|
||||||
// TODO(ntv): remove dependency on Builder once we support non-identity
|
// TODO(ntv): remove dependency on Builder once we support non-identity
|
||||||
// layout map.
|
// layout map.
|
||||||
Builder b(memRefType->getContext());
|
Builder b(memRefType.getContext());
|
||||||
assert(layoutMap.empty() ||
|
assert(layoutMap.empty() ||
|
||||||
layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
|
layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
|
||||||
(void)layoutMap;
|
(void)layoutMap;
|
||||||
|
@ -170,7 +170,7 @@ static bool isContiguousAccess(const MLValue &input,
|
||||||
using namespace functional;
|
using namespace functional;
|
||||||
auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
|
auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
|
||||||
memoryOp->getIndices());
|
memoryOp->getIndices());
|
||||||
auto *memRefType = memoryOp->getMemRefType();
|
auto memRefType = memoryOp->getMemRefType();
|
||||||
for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
|
for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
|
||||||
if (fastestVaryingDim == (numIndices - 1) - d) {
|
if (fastestVaryingDim == (numIndices - 1) - d) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -184,8 +184,8 @@ static bool isContiguousAccess(const MLValue &input,
|
||||||
|
|
||||||
template <typename LoadOrStoreOpPointer>
|
template <typename LoadOrStoreOpPointer>
|
||||||
static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
|
static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
|
||||||
auto *memRefType = memoryOp->getMemRefType();
|
auto memRefType = memoryOp->getMemRefType();
|
||||||
return isa<VectorType>(memRefType->getElementType());
|
return memRefType.getElementType().template isa<VectorType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {
|
bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {
|
||||||
|
|
|
@ -195,7 +195,7 @@ bool CFGFuncVerifier::verify() {
|
||||||
|
|
||||||
// Verify that the argument list of the function and the arg list of the first
|
// Verify that the argument list of the function and the arg list of the first
|
||||||
// block line up.
|
// block line up.
|
||||||
auto fnInputTypes = fn.getType()->getInputs();
|
auto fnInputTypes = fn.getType().getInputs();
|
||||||
if (fnInputTypes.size() != firstBB->getNumArguments())
|
if (fnInputTypes.size() != firstBB->getNumArguments())
|
||||||
return failure("first block of cfgfunc must have " +
|
return failure("first block of cfgfunc must have " +
|
||||||
Twine(fnInputTypes.size()) +
|
Twine(fnInputTypes.size()) +
|
||||||
|
@ -306,7 +306,7 @@ bool CFGFuncVerifier::verifyBBArguments(ArrayRef<InstOperand> operands,
|
||||||
|
|
||||||
bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
|
bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
|
||||||
// Verify that the return operands match the results of the function.
|
// Verify that the return operands match the results of the function.
|
||||||
auto results = fn.getType()->getResults();
|
auto results = fn.getType().getResults();
|
||||||
if (inst.getNumOperands() != results.size())
|
if (inst.getNumOperands() != results.size())
|
||||||
return failure("return has " + Twine(inst.getNumOperands()) +
|
return failure("return has " + Twine(inst.getNumOperands()) +
|
||||||
" operands, but enclosing function returns " +
|
" operands, but enclosing function returns " +
|
||||||
|
|
|
@ -122,7 +122,7 @@ private:
|
||||||
void visitForStmt(const ForStmt *forStmt);
|
void visitForStmt(const ForStmt *forStmt);
|
||||||
void visitIfStmt(const IfStmt *ifStmt);
|
void visitIfStmt(const IfStmt *ifStmt);
|
||||||
void visitOperationStmt(const OperationStmt *opStmt);
|
void visitOperationStmt(const OperationStmt *opStmt);
|
||||||
void visitType(const Type *type);
|
void visitType(Type type);
|
||||||
void visitAttribute(Attribute attr);
|
void visitAttribute(Attribute attr);
|
||||||
void visitOperation(const Operation *op);
|
void visitOperation(const Operation *op);
|
||||||
|
|
||||||
|
@ -135,16 +135,16 @@ private:
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
// TODO Support visiting other types/instructions when implemented.
|
// TODO Support visiting other types/instructions when implemented.
|
||||||
void ModuleState::visitType(const Type *type) {
|
void ModuleState::visitType(Type type) {
|
||||||
if (auto *funcType = dyn_cast<FunctionType>(type)) {
|
if (auto funcType = type.dyn_cast<FunctionType>()) {
|
||||||
// Visit input and result types for functions.
|
// Visit input and result types for functions.
|
||||||
for (auto *input : funcType->getInputs())
|
for (auto input : funcType.getInputs())
|
||||||
visitType(input);
|
visitType(input);
|
||||||
for (auto *result : funcType->getResults())
|
for (auto result : funcType.getResults())
|
||||||
visitType(result);
|
visitType(result);
|
||||||
} else if (auto *memref = dyn_cast<MemRefType>(type)) {
|
} else if (auto memref = type.dyn_cast<MemRefType>()) {
|
||||||
// Visit affine maps in memref type.
|
// Visit affine maps in memref type.
|
||||||
for (auto map : memref->getAffineMaps()) {
|
for (auto map : memref.getAffineMaps()) {
|
||||||
recordAffineMapReference(map);
|
recordAffineMapReference(map);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -271,7 +271,7 @@ public:
|
||||||
void print(const Module *module);
|
void print(const Module *module);
|
||||||
void printFunctionReference(const Function *func);
|
void printFunctionReference(const Function *func);
|
||||||
void printAttribute(Attribute attr);
|
void printAttribute(Attribute attr);
|
||||||
void printType(const Type *type);
|
void printType(Type type);
|
||||||
void print(const Function *fn);
|
void print(const Function *fn);
|
||||||
void print(const ExtFunction *fn);
|
void print(const ExtFunction *fn);
|
||||||
void print(const CFGFunction *fn);
|
void print(const CFGFunction *fn);
|
||||||
|
@ -290,7 +290,7 @@ protected:
|
||||||
void printFunctionAttributes(const Function *fn);
|
void printFunctionAttributes(const Function *fn);
|
||||||
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
||||||
ArrayRef<const char *> elidedAttrs = {});
|
ArrayRef<const char *> elidedAttrs = {});
|
||||||
void printFunctionResultType(const FunctionType *type);
|
void printFunctionResultType(FunctionType type);
|
||||||
void printAffineMapId(int affineMapId) const;
|
void printAffineMapId(int affineMapId) const;
|
||||||
void printAffineMapReference(AffineMap affineMap);
|
void printAffineMapReference(AffineMap affineMap);
|
||||||
void printIntegerSetId(int integerSetId) const;
|
void printIntegerSetId(int integerSetId) const;
|
||||||
|
@ -489,9 +489,9 @@ void ModulePrinter::printAttribute(Attribute attr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
||||||
auto *type = attr.getType();
|
auto type = attr.getType();
|
||||||
auto shape = type->getShape();
|
auto shape = type.getShape();
|
||||||
auto rank = type->getRank();
|
auto rank = type.getRank();
|
||||||
|
|
||||||
SmallVector<Attribute, 16> elements;
|
SmallVector<Attribute, 16> elements;
|
||||||
attr.getValues(elements);
|
attr.getValues(elements);
|
||||||
|
@ -541,8 +541,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
||||||
os << ']';
|
os << ']';
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printType(const Type *type) {
|
void ModulePrinter::printType(Type type) {
|
||||||
switch (type->getKind()) {
|
switch (type.getKind()) {
|
||||||
case Type::Kind::Index:
|
case Type::Kind::Index:
|
||||||
os << "index";
|
os << "index";
|
||||||
return;
|
return;
|
||||||
|
@ -581,71 +581,71 @@ void ModulePrinter::printType(const Type *type) {
|
||||||
return;
|
return;
|
||||||
|
|
||||||
case Type::Kind::Integer: {
|
case Type::Kind::Integer: {
|
||||||
auto *integer = cast<IntegerType>(type);
|
auto integer = type.cast<IntegerType>();
|
||||||
os << 'i' << integer->getWidth();
|
os << 'i' << integer.getWidth();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case Type::Kind::Function: {
|
case Type::Kind::Function: {
|
||||||
auto *func = cast<FunctionType>(type);
|
auto func = type.cast<FunctionType>();
|
||||||
os << '(';
|
os << '(';
|
||||||
interleaveComma(func->getInputs(), [&](Type *type) { printType(type); });
|
interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
|
||||||
os << ") -> ";
|
os << ") -> ";
|
||||||
auto results = func->getResults();
|
auto results = func.getResults();
|
||||||
if (results.size() == 1)
|
if (results.size() == 1)
|
||||||
os << *results[0];
|
os << results[0];
|
||||||
else {
|
else {
|
||||||
os << '(';
|
os << '(';
|
||||||
interleaveComma(results, [&](Type *type) { printType(type); });
|
interleaveComma(results, [&](Type type) { printType(type); });
|
||||||
os << ')';
|
os << ')';
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case Type::Kind::Vector: {
|
case Type::Kind::Vector: {
|
||||||
auto *v = cast<VectorType>(type);
|
auto v = type.cast<VectorType>();
|
||||||
os << "vector<";
|
os << "vector<";
|
||||||
for (auto dim : v->getShape())
|
for (auto dim : v.getShape())
|
||||||
os << dim << 'x';
|
os << dim << 'x';
|
||||||
os << *v->getElementType() << '>';
|
os << v.getElementType() << '>';
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case Type::Kind::RankedTensor: {
|
case Type::Kind::RankedTensor: {
|
||||||
auto *v = cast<RankedTensorType>(type);
|
auto v = type.cast<RankedTensorType>();
|
||||||
os << "tensor<";
|
os << "tensor<";
|
||||||
for (auto dim : v->getShape()) {
|
for (auto dim : v.getShape()) {
|
||||||
if (dim < 0)
|
if (dim < 0)
|
||||||
os << '?';
|
os << '?';
|
||||||
else
|
else
|
||||||
os << dim;
|
os << dim;
|
||||||
os << 'x';
|
os << 'x';
|
||||||
}
|
}
|
||||||
os << *v->getElementType() << '>';
|
os << v.getElementType() << '>';
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case Type::Kind::UnrankedTensor: {
|
case Type::Kind::UnrankedTensor: {
|
||||||
auto *v = cast<UnrankedTensorType>(type);
|
auto v = type.cast<UnrankedTensorType>();
|
||||||
os << "tensor<*x";
|
os << "tensor<*x";
|
||||||
printType(v->getElementType());
|
printType(v.getElementType());
|
||||||
os << '>';
|
os << '>';
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case Type::Kind::MemRef: {
|
case Type::Kind::MemRef: {
|
||||||
auto *v = cast<MemRefType>(type);
|
auto v = type.cast<MemRefType>();
|
||||||
os << "memref<";
|
os << "memref<";
|
||||||
for (auto dim : v->getShape()) {
|
for (auto dim : v.getShape()) {
|
||||||
if (dim < 0)
|
if (dim < 0)
|
||||||
os << '?';
|
os << '?';
|
||||||
else
|
else
|
||||||
os << dim;
|
os << dim;
|
||||||
os << 'x';
|
os << 'x';
|
||||||
}
|
}
|
||||||
printType(v->getElementType());
|
printType(v.getElementType());
|
||||||
for (auto map : v->getAffineMaps()) {
|
for (auto map : v.getAffineMaps()) {
|
||||||
os << ", ";
|
os << ", ";
|
||||||
printAffineMapReference(map);
|
printAffineMapReference(map);
|
||||||
}
|
}
|
||||||
// Only print the memory space if it is the non-default one.
|
// Only print the memory space if it is the non-default one.
|
||||||
if (v->getMemorySpace())
|
if (v.getMemorySpace())
|
||||||
os << ", " << v->getMemorySpace();
|
os << ", " << v.getMemorySpace();
|
||||||
os << '>';
|
os << '>';
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -842,18 +842,18 @@ void ModulePrinter::printIntegerSet(IntegerSet set) {
|
||||||
// Function printing
|
// Function printing
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void ModulePrinter::printFunctionResultType(const FunctionType *type) {
|
void ModulePrinter::printFunctionResultType(FunctionType type) {
|
||||||
switch (type->getResults().size()) {
|
switch (type.getResults().size()) {
|
||||||
case 0:
|
case 0:
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1:
|
||||||
os << " -> ";
|
os << " -> ";
|
||||||
printType(type->getResults()[0]);
|
printType(type.getResults()[0]);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
os << " -> (";
|
os << " -> (";
|
||||||
interleaveComma(type->getResults(),
|
interleaveComma(type.getResults(),
|
||||||
[&](Type *eltType) { printType(eltType); });
|
[&](Type eltType) { printType(eltType); });
|
||||||
os << ')';
|
os << ')';
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -871,8 +871,7 @@ void ModulePrinter::printFunctionSignature(const Function *fn) {
|
||||||
auto type = fn->getType();
|
auto type = fn->getType();
|
||||||
|
|
||||||
os << "@" << fn->getName() << '(';
|
os << "@" << fn->getName() << '(';
|
||||||
interleaveComma(type->getInputs(),
|
interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
|
||||||
[&](Type *eltType) { printType(eltType); });
|
|
||||||
os << ')';
|
os << ')';
|
||||||
|
|
||||||
printFunctionResultType(type);
|
printFunctionResultType(type);
|
||||||
|
@ -937,7 +936,7 @@ public:
|
||||||
|
|
||||||
// Implement OpAsmPrinter.
|
// Implement OpAsmPrinter.
|
||||||
raw_ostream &getStream() const { return os; }
|
raw_ostream &getStream() const { return os; }
|
||||||
void printType(const Type *type) { ModulePrinter::printType(type); }
|
void printType(Type type) { ModulePrinter::printType(type); }
|
||||||
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
|
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
|
||||||
void printAffineMap(AffineMap map) {
|
void printAffineMap(AffineMap map) {
|
||||||
return ModulePrinter::printAffineMapReference(map);
|
return ModulePrinter::printAffineMapReference(map);
|
||||||
|
@ -974,10 +973,10 @@ protected:
|
||||||
if (auto *op = value->getDefiningOperation()) {
|
if (auto *op = value->getDefiningOperation()) {
|
||||||
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
|
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
|
||||||
// i1 constants get special names.
|
// i1 constants get special names.
|
||||||
if (intOp->getType()->isInteger(1)) {
|
if (intOp->getType().isInteger(1)) {
|
||||||
specialName << (intOp->getValue() ? "true" : "false");
|
specialName << (intOp->getValue() ? "true" : "false");
|
||||||
} else {
|
} else {
|
||||||
specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
|
specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
|
||||||
}
|
}
|
||||||
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
|
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
|
||||||
specialName << 'c' << intOp->getValue();
|
specialName << 'c' << intOp->getValue();
|
||||||
|
@ -1579,7 +1578,7 @@ void Attribute::dump() const { print(llvm::errs()); }
|
||||||
|
|
||||||
void Type::print(raw_ostream &os) const {
|
void Type::print(raw_ostream &os) const {
|
||||||
ModuleState state(getContext());
|
ModuleState state(getContext());
|
||||||
ModulePrinter(os, state).printType(this);
|
ModulePrinter(os, state).printType(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Type::dump() const { print(llvm::errs()); }
|
void Type::dump() const { print(llvm::errs()); }
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/IntegerSet.h"
|
#include "mlir/IR/IntegerSet.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
#include "llvm/Support/TrailingObjects.h"
|
#include "llvm/Support/TrailingObjects.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -86,7 +87,7 @@ struct IntegerSetAttributeStorage : public AttributeStorage {
|
||||||
|
|
||||||
/// An attribute representing a reference to a type.
|
/// An attribute representing a reference to a type.
|
||||||
struct TypeAttributeStorage : public AttributeStorage {
|
struct TypeAttributeStorage : public AttributeStorage {
|
||||||
Type *value;
|
Type value;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An attribute representing a reference to a function.
|
/// An attribute representing a reference to a function.
|
||||||
|
@ -96,7 +97,7 @@ struct FunctionAttributeStorage : public AttributeStorage {
|
||||||
|
|
||||||
/// A base attribute representing a reference to a vector or tensor constant.
|
/// A base attribute representing a reference to a vector or tensor constant.
|
||||||
struct ElementsAttributeStorage : public AttributeStorage {
|
struct ElementsAttributeStorage : public AttributeStorage {
|
||||||
VectorOrTensorType *type;
|
VectorOrTensorType type;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An attribute representing a reference to a vector or tensor constant,
|
/// An attribute representing a reference to a vector or tensor constant,
|
||||||
|
|
|
@ -75,9 +75,7 @@ IntegerSet IntegerSetAttr::getValue() const {
|
||||||
|
|
||||||
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
Type *TypeAttr::getValue() const {
|
Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
|
||||||
return static_cast<ImplType *>(attr)->value;
|
|
||||||
}
|
|
||||||
|
|
||||||
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
|
@ -85,11 +83,11 @@ Function *FunctionAttr::getValue() const {
|
||||||
return static_cast<ImplType *>(attr)->value;
|
return static_cast<ImplType *>(attr)->value;
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
|
FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
|
||||||
|
|
||||||
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
VectorOrTensorType *ElementsAttr::getType() const {
|
VectorOrTensorType ElementsAttr::getType() const {
|
||||||
return static_cast<ImplType *>(attr)->type;
|
return static_cast<ImplType *>(attr)->type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,8 +164,8 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
|
||||||
|
|
||||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||||
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
|
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
|
||||||
auto elementNum = getType()->getNumElements();
|
auto elementNum = getType().getNumElements();
|
||||||
auto context = getType()->getContext();
|
auto context = getType().getContext();
|
||||||
values.reserve(elementNum);
|
values.reserve(elementNum);
|
||||||
if (bitsWidth == 64) {
|
if (bitsWidth == 64) {
|
||||||
ArrayRef<int64_t> vs(
|
ArrayRef<int64_t> vs(
|
||||||
|
@ -192,8 +190,8 @@ DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
|
||||||
: DenseElementsAttr(ptr) {}
|
: DenseElementsAttr(ptr) {}
|
||||||
|
|
||||||
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||||
auto elementNum = getType()->getNumElements();
|
auto elementNum = getType().getNumElements();
|
||||||
auto context = getType()->getContext();
|
auto context = getType().getContext();
|
||||||
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
|
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
|
||||||
getRawData().size() / 8});
|
getRawData().size() / 8});
|
||||||
values.reserve(elementNum);
|
values.reserve(elementNum);
|
||||||
|
|
|
@ -33,18 +33,18 @@ BasicBlock::~BasicBlock() {
|
||||||
// Argument list management.
|
// Argument list management.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
BBArgument *BasicBlock::addArgument(Type *type) {
|
BBArgument *BasicBlock::addArgument(Type type) {
|
||||||
auto *arg = new BBArgument(type, this);
|
auto *arg = new BBArgument(type, this);
|
||||||
arguments.push_back(arg);
|
arguments.push_back(arg);
|
||||||
return arg;
|
return arg;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add one argument to the argument list for each type specified in the list.
|
/// Add one argument to the argument list for each type specified in the list.
|
||||||
auto BasicBlock::addArguments(ArrayRef<Type *> types)
|
auto BasicBlock::addArguments(ArrayRef<Type> types)
|
||||||
-> llvm::iterator_range<args_iterator> {
|
-> llvm::iterator_range<args_iterator> {
|
||||||
arguments.reserve(arguments.size() + types.size());
|
arguments.reserve(arguments.size() + types.size());
|
||||||
auto initialSize = arguments.size();
|
auto initialSize = arguments.size();
|
||||||
for (auto *type : types) {
|
for (auto type : types) {
|
||||||
addArgument(type);
|
addArgument(type);
|
||||||
}
|
}
|
||||||
return {arguments.data() + initialSize, arguments.data() + arguments.size()};
|
return {arguments.data() + initialSize, arguments.data() + arguments.size()};
|
||||||
|
|
|
@ -52,59 +52,58 @@ FileLineColLoc *Builder::getFileLineColLoc(UniquedFilename filename,
|
||||||
// Types.
|
// Types.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
FloatType *Builder::getBF16Type() { return Type::getBF16(context); }
|
FloatType Builder::getBF16Type() { return Type::getBF16(context); }
|
||||||
|
|
||||||
FloatType *Builder::getF16Type() { return Type::getF16(context); }
|
FloatType Builder::getF16Type() { return Type::getF16(context); }
|
||||||
|
|
||||||
FloatType *Builder::getF32Type() { return Type::getF32(context); }
|
FloatType Builder::getF32Type() { return Type::getF32(context); }
|
||||||
|
|
||||||
FloatType *Builder::getF64Type() { return Type::getF64(context); }
|
FloatType Builder::getF64Type() { return Type::getF64(context); }
|
||||||
|
|
||||||
OtherType *Builder::getIndexType() { return Type::getIndex(context); }
|
OtherType Builder::getIndexType() { return Type::getIndex(context); }
|
||||||
|
|
||||||
OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
|
OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
|
||||||
|
|
||||||
OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
|
OtherType Builder::getTFResourceType() { return Type::getTFResource(context); }
|
||||||
|
|
||||||
OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
|
OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); }
|
||||||
|
|
||||||
OtherType *Builder::getTFComplex64Type() {
|
OtherType Builder::getTFComplex64Type() {
|
||||||
return Type::getTFComplex64(context);
|
return Type::getTFComplex64(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
OtherType *Builder::getTFComplex128Type() {
|
OtherType Builder::getTFComplex128Type() {
|
||||||
return Type::getTFComplex128(context);
|
return Type::getTFComplex128(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
|
OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
|
||||||
|
|
||||||
OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
|
OtherType Builder::getTFStringType() { return Type::getTFString(context); }
|
||||||
|
|
||||||
IntegerType *Builder::getIntegerType(unsigned width) {
|
IntegerType Builder::getIntegerType(unsigned width) {
|
||||||
return Type::getInteger(width, context);
|
return Type::getInteger(width, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
|
FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
|
||||||
ArrayRef<Type *> results) {
|
ArrayRef<Type> results) {
|
||||||
return FunctionType::get(inputs, results, context);
|
return FunctionType::get(inputs, results, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
|
MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType,
|
||||||
ArrayRef<AffineMap> affineMapComposition,
|
ArrayRef<AffineMap> affineMapComposition,
|
||||||
unsigned memorySpace) {
|
unsigned memorySpace) {
|
||||||
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
|
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
|
VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) {
|
||||||
return VectorType::get(shape, elementType);
|
return VectorType::get(shape, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
RankedTensorType *Builder::getTensorType(ArrayRef<int> shape,
|
RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) {
|
||||||
Type *elementType) {
|
|
||||||
return RankedTensorType::get(shape, elementType);
|
return RankedTensorType::get(shape, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
UnrankedTensorType *Builder::getTensorType(Type *elementType) {
|
UnrankedTensorType Builder::getTensorType(Type elementType) {
|
||||||
return UnrankedTensorType::get(elementType);
|
return UnrankedTensorType::get(elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,7 +143,7 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
|
||||||
return IntegerSetAttr::get(set);
|
return IntegerSetAttr::get(set);
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeAttr Builder::getTypeAttr(Type *type) {
|
TypeAttr Builder::getTypeAttr(Type type) {
|
||||||
return TypeAttr::get(type, context);
|
return TypeAttr::get(type, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,23 +151,23 @@ FunctionAttr Builder::getFunctionAttr(const Function *value) {
|
||||||
return FunctionAttr::get(value, context);
|
return FunctionAttr::get(value, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
|
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,
|
||||||
Attribute elt) {
|
Attribute elt) {
|
||||||
return SplatElementsAttr::get(type, elt);
|
return SplatElementsAttr::get(type, elt);
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
|
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
|
||||||
ArrayRef<char> data) {
|
ArrayRef<char> data) {
|
||||||
return DenseElementsAttr::get(type, data);
|
return DenseElementsAttr::get(type, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
|
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
|
||||||
DenseIntElementsAttr indices,
|
DenseIntElementsAttr indices,
|
||||||
DenseElementsAttr values) {
|
DenseElementsAttr values) {
|
||||||
return SparseElementsAttr::get(type, indices, values);
|
return SparseElementsAttr::get(type, indices, values);
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
|
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type,
|
||||||
StringRef bytes) {
|
StringRef bytes) {
|
||||||
return OpaqueElementsAttr::get(type, bytes);
|
return OpaqueElementsAttr::get(type, bytes);
|
||||||
}
|
}
|
||||||
|
@ -296,7 +295,7 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
|
||||||
OperationStmt *MLFuncBuilder::createOperation(Location *location,
|
OperationStmt *MLFuncBuilder::createOperation(Location *location,
|
||||||
OperationName name,
|
OperationName name,
|
||||||
ArrayRef<MLValue *> operands,
|
ArrayRef<MLValue *> operands,
|
||||||
ArrayRef<Type *> types,
|
ArrayRef<Type> types,
|
||||||
ArrayRef<NamedAttribute> attrs) {
|
ArrayRef<NamedAttribute> attrs) {
|
||||||
auto *op = OperationStmt::create(location, name, operands, types, attrs,
|
auto *op = OperationStmt::create(location, name, operands, types, attrs,
|
||||||
getContext());
|
getContext());
|
||||||
|
|
|
@ -63,7 +63,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
|
||||||
numDims = opInfos.size();
|
numDims = opInfos.size();
|
||||||
|
|
||||||
// Parse the optional symbol operands.
|
// Parse the optional symbol operands.
|
||||||
auto *affineIntTy = parser->getBuilder().getIndexType();
|
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||||
if (parser->parseOperandList(opInfos, -1,
|
if (parser->parseOperandList(opInfos, -1,
|
||||||
OpAsmParser::Delimiter::OptionalSquare) ||
|
OpAsmParser::Delimiter::OptionalSquare) ||
|
||||||
parser->resolveOperands(opInfos, affineIntTy, operands))
|
parser->resolveOperands(opInfos, affineIntTy, operands))
|
||||||
|
@ -84,7 +84,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result,
|
||||||
|
|
||||||
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
auto &builder = parser->getBuilder();
|
auto &builder = parser->getBuilder();
|
||||||
auto *affineIntTy = builder.getIndexType();
|
auto affineIntTy = builder.getIndexType();
|
||||||
|
|
||||||
AffineMapAttr mapAttr;
|
AffineMapAttr mapAttr;
|
||||||
unsigned numDims;
|
unsigned numDims;
|
||||||
|
@ -171,7 +171,7 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
|
||||||
|
|
||||||
/// Builds a constant op with the specified attribute value and result type.
|
/// Builds a constant op with the specified attribute value and result type.
|
||||||
void ConstantOp::build(Builder *builder, OperationState *result,
|
void ConstantOp::build(Builder *builder, OperationState *result,
|
||||||
Attribute value, Type *type) {
|
Attribute value, Type type) {
|
||||||
result->addAttribute("value", value);
|
result->addAttribute("value", value);
|
||||||
result->types.push_back(type);
|
result->types.push_back(type);
|
||||||
}
|
}
|
||||||
|
@ -181,12 +181,12 @@ void ConstantOp::print(OpAsmPrinter *p) const {
|
||||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
|
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
|
||||||
|
|
||||||
if (!getValue().isa<FunctionAttr>())
|
if (!getValue().isa<FunctionAttr>())
|
||||||
*p << " : " << *getType();
|
*p << " : " << getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
Attribute valueAttr;
|
Attribute valueAttr;
|
||||||
Type *type;
|
Type type;
|
||||||
|
|
||||||
if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
|
if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
|
||||||
parser->parseOptionalAttributeDict(result->attributes))
|
parser->parseOptionalAttributeDict(result->attributes))
|
||||||
|
@ -208,33 +208,33 @@ bool ConstantOp::verify() const {
|
||||||
if (!value)
|
if (!value)
|
||||||
return emitOpError("requires a 'value' attribute");
|
return emitOpError("requires a 'value' attribute");
|
||||||
|
|
||||||
auto *type = this->getType();
|
auto type = this->getType();
|
||||||
if (isa<IntegerType>(type) || type->isIndex()) {
|
if (type.isa<IntegerType>() || type.isIndex()) {
|
||||||
if (!value.isa<IntegerAttr>())
|
if (!value.isa<IntegerAttr>())
|
||||||
return emitOpError(
|
return emitOpError(
|
||||||
"requires 'value' to be an integer for an integer result type");
|
"requires 'value' to be an integer for an integer result type");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<FloatType>(type)) {
|
if (type.isa<FloatType>()) {
|
||||||
if (!value.isa<FloatAttr>())
|
if (!value.isa<FloatAttr>())
|
||||||
return emitOpError("requires 'value' to be a floating point constant");
|
return emitOpError("requires 'value' to be a floating point constant");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<VectorOrTensorType>(type)) {
|
if (type.isa<VectorOrTensorType>()) {
|
||||||
if (!value.isa<ElementsAttr>())
|
if (!value.isa<ElementsAttr>())
|
||||||
return emitOpError("requires 'value' to be a vector/tensor constant");
|
return emitOpError("requires 'value' to be a vector/tensor constant");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (type->isTFString()) {
|
if (type.isTFString()) {
|
||||||
if (!value.isa<StringAttr>())
|
if (!value.isa<StringAttr>())
|
||||||
return emitOpError("requires 'value' to be a string constant");
|
return emitOpError("requires 'value' to be a string constant");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<FunctionType>(type)) {
|
if (type.isa<FunctionType>()) {
|
||||||
if (!value.isa<FunctionAttr>())
|
if (!value.isa<FunctionAttr>())
|
||||||
return emitOpError("requires 'value' to be a function reference");
|
return emitOpError("requires 'value' to be a function reference");
|
||||||
return false;
|
return false;
|
||||||
|
@ -251,19 +251,19 @@ Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConstantFloatOp::build(Builder *builder, OperationState *result,
|
void ConstantFloatOp::build(Builder *builder, OperationState *result,
|
||||||
const APFloat &value, FloatType *type) {
|
const APFloat &value, FloatType type) {
|
||||||
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
|
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConstantFloatOp::isClassFor(const Operation *op) {
|
bool ConstantFloatOp::isClassFor(const Operation *op) {
|
||||||
return ConstantOp::isClassFor(op) &&
|
return ConstantOp::isClassFor(op) &&
|
||||||
isa<FloatType>(op->getResult(0)->getType());
|
op->getResult(0)->getType().isa<FloatType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ConstantIntOp only matches values whose result type is an IntegerType.
|
/// ConstantIntOp only matches values whose result type is an IntegerType.
|
||||||
bool ConstantIntOp::isClassFor(const Operation *op) {
|
bool ConstantIntOp::isClassFor(const Operation *op) {
|
||||||
return ConstantOp::isClassFor(op) &&
|
return ConstantOp::isClassFor(op) &&
|
||||||
isa<IntegerType>(op->getResult(0)->getType());
|
op->getResult(0)->getType().isa<IntegerType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConstantIntOp::build(Builder *builder, OperationState *result,
|
void ConstantIntOp::build(Builder *builder, OperationState *result,
|
||||||
|
@ -275,14 +275,14 @@ void ConstantIntOp::build(Builder *builder, OperationState *result,
|
||||||
/// Build a constant int op producing an integer with the specified type,
|
/// Build a constant int op producing an integer with the specified type,
|
||||||
/// which must be an integer type.
|
/// which must be an integer type.
|
||||||
void ConstantIntOp::build(Builder *builder, OperationState *result,
|
void ConstantIntOp::build(Builder *builder, OperationState *result,
|
||||||
int64_t value, Type *type) {
|
int64_t value, Type type) {
|
||||||
assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type");
|
assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
|
||||||
ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
|
ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ConstantIndexOp only matches values whose result type is Index.
|
/// ConstantIndexOp only matches values whose result type is Index.
|
||||||
bool ConstantIndexOp::isClassFor(const Operation *op) {
|
bool ConstantIndexOp::isClassFor(const Operation *op) {
|
||||||
return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
|
return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConstantIndexOp::build(Builder *builder, OperationState *result,
|
void ConstantIndexOp::build(Builder *builder, OperationState *result,
|
||||||
|
@ -302,7 +302,7 @@ void ReturnOp::build(Builder *builder, OperationState *result,
|
||||||
|
|
||||||
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
SmallVector<OpAsmParser::OperandType, 2> opInfo;
|
SmallVector<OpAsmParser::OperandType, 2> opInfo;
|
||||||
SmallVector<Type *, 2> types;
|
SmallVector<Type, 2> types;
|
||||||
llvm::SMLoc loc;
|
llvm::SMLoc loc;
|
||||||
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
|
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
|
||||||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
|
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
|
||||||
|
@ -330,7 +330,7 @@ bool ReturnOp::verify() const {
|
||||||
|
|
||||||
// The operand number and types must match the function signature.
|
// The operand number and types must match the function signature.
|
||||||
MLFunction *function = cast<MLFunction>(block);
|
MLFunction *function = cast<MLFunction>(block);
|
||||||
const auto &results = function->getType()->getResults();
|
const auto &results = function->getType().getResults();
|
||||||
if (stmt->getNumOperands() != results.size())
|
if (stmt->getNumOperands() != results.size())
|
||||||
return emitOpError("has " + Twine(stmt->getNumOperands()) +
|
return emitOpError("has " + Twine(stmt->getNumOperands()) +
|
||||||
" operands, but enclosing function returns " +
|
" operands, but enclosing function returns " +
|
||||||
|
|
|
@ -28,8 +28,8 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
Function::Function(Kind kind, Location *location, StringRef name,
|
Function::Function(Kind kind, Location *location, StringRef name,
|
||||||
FunctionType *type, ArrayRef<NamedAttribute> attrs)
|
FunctionType type, ArrayRef<NamedAttribute> attrs)
|
||||||
: nameAndKind(Identifier::get(name, type->getContext()), kind),
|
: nameAndKind(Identifier::get(name, type.getContext()), kind),
|
||||||
location(location), type(type) {
|
location(location), type(type) {
|
||||||
this->attrs = AttributeListStorage::get(attrs, getContext());
|
this->attrs = AttributeListStorage::get(attrs, getContext());
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ ArrayRef<NamedAttribute> Function::getAttrs() const {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
MLIRContext *Function::getContext() const { return getType()->getContext(); }
|
MLIRContext *Function::getContext() const { return getType().getContext(); }
|
||||||
|
|
||||||
/// Delete this object.
|
/// Delete this object.
|
||||||
void Function::destroy() {
|
void Function::destroy() {
|
||||||
|
@ -159,7 +159,7 @@ void Function::emitError(const Twine &message) const {
|
||||||
// ExtFunction implementation.
|
// ExtFunction implementation.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
|
ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs)
|
ArrayRef<NamedAttribute> attrs)
|
||||||
: Function(Kind::ExtFunc, location, name, type, attrs) {}
|
: Function(Kind::ExtFunc, location, name, type, attrs) {}
|
||||||
|
|
||||||
|
@ -167,7 +167,7 @@ ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
|
||||||
// CFGFunction implementation.
|
// CFGFunction implementation.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type,
|
CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs)
|
ArrayRef<NamedAttribute> attrs)
|
||||||
: Function(Kind::CFGFunc, location, name, type, attrs) {}
|
: Function(Kind::CFGFunc, location, name, type, attrs) {}
|
||||||
|
|
||||||
|
@ -188,9 +188,9 @@ CFGFunction::~CFGFunction() {
|
||||||
|
|
||||||
/// Create a new MLFunction with the specific fields.
|
/// Create a new MLFunction with the specific fields.
|
||||||
MLFunction *MLFunction::create(Location *location, StringRef name,
|
MLFunction *MLFunction::create(Location *location, StringRef name,
|
||||||
FunctionType *type,
|
FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs) {
|
ArrayRef<NamedAttribute> attrs) {
|
||||||
const auto &argTypes = type->getInputs();
|
const auto &argTypes = type.getInputs();
|
||||||
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
|
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
|
||||||
void *rawMem = malloc(byteSize);
|
void *rawMem = malloc(byteSize);
|
||||||
|
|
||||||
|
@ -204,7 +204,7 @@ MLFunction *MLFunction::create(Location *location, StringRef name,
|
||||||
return function;
|
return function;
|
||||||
}
|
}
|
||||||
|
|
||||||
MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type,
|
MLFunction::MLFunction(Location *location, StringRef name, FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs)
|
ArrayRef<NamedAttribute> attrs)
|
||||||
: Function(Kind::MLFunc, location, name, type, attrs),
|
: Function(Kind::MLFunc, location, name, type, attrs),
|
||||||
StmtBlock(StmtBlockKind::MLFunc) {}
|
StmtBlock(StmtBlockKind::MLFunc) {}
|
||||||
|
|
|
@ -143,7 +143,7 @@ void Instruction::emitError(const Twine &message) const {
|
||||||
/// Create a new OperationInst with the specified fields.
|
/// Create a new OperationInst with the specified fields.
|
||||||
OperationInst *OperationInst::create(Location *location, OperationName name,
|
OperationInst *OperationInst::create(Location *location, OperationName name,
|
||||||
ArrayRef<CFGValue *> operands,
|
ArrayRef<CFGValue *> operands,
|
||||||
ArrayRef<Type *> resultTypes,
|
ArrayRef<Type> resultTypes,
|
||||||
ArrayRef<NamedAttribute> attributes,
|
ArrayRef<NamedAttribute> attributes,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(),
|
auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(),
|
||||||
|
@ -167,7 +167,7 @@ OperationInst *OperationInst::create(Location *location, OperationName name,
|
||||||
|
|
||||||
OperationInst *OperationInst::clone() const {
|
OperationInst *OperationInst::clone() const {
|
||||||
SmallVector<CFGValue *, 8> operands;
|
SmallVector<CFGValue *, 8> operands;
|
||||||
SmallVector<Type *, 8> resultTypes;
|
SmallVector<Type, 8> resultTypes;
|
||||||
|
|
||||||
// Put together the operands and results.
|
// Put together the operands and results.
|
||||||
for (auto *operand : getOperands())
|
for (auto *operand : getOperands())
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "AttributeDetail.h"
|
#include "AttributeDetail.h"
|
||||||
#include "AttributeListStorage.h"
|
#include "AttributeListStorage.h"
|
||||||
#include "IntegerSetDetail.h"
|
#include "IntegerSetDetail.h"
|
||||||
|
#include "TypeDetail.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
@ -44,11 +45,11 @@ using namespace mlir::detail;
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
|
struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> {
|
||||||
// Functions are uniqued based on their inputs and results.
|
// Functions are uniqued based on their inputs and results.
|
||||||
using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>;
|
using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>;
|
||||||
using DenseMapInfo<FunctionType *>::getHashValue;
|
using DenseMapInfo<FunctionTypeStorage *>::getHashValue;
|
||||||
using DenseMapInfo<FunctionType *>::isEqual;
|
using DenseMapInfo<FunctionTypeStorage *>::isEqual;
|
||||||
|
|
||||||
static unsigned getHashValue(KeyTy key) {
|
static unsigned getHashValue(KeyTy key) {
|
||||||
return hash_combine(
|
return hash_combine(
|
||||||
|
@ -56,7 +57,7 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
|
||||||
hash_combine_range(key.second.begin(), key.second.end()));
|
hash_combine_range(key.second.begin(), key.second.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
|
static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) {
|
||||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
return false;
|
return false;
|
||||||
return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
|
return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
|
||||||
|
@ -109,65 +110,64 @@ struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
|
struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> {
|
||||||
// Vectors are uniqued based on their element type and shape.
|
// Vectors are uniqued based on their element type and shape.
|
||||||
using KeyTy = std::pair<Type *, ArrayRef<int>>;
|
using KeyTy = std::pair<Type, ArrayRef<int>>;
|
||||||
using DenseMapInfo<VectorType *>::getHashValue;
|
using DenseMapInfo<VectorTypeStorage *>::getHashValue;
|
||||||
using DenseMapInfo<VectorType *>::isEqual;
|
using DenseMapInfo<VectorTypeStorage *>::isEqual;
|
||||||
|
|
||||||
static unsigned getHashValue(KeyTy key) {
|
static unsigned getHashValue(KeyTy key) {
|
||||||
return hash_combine(
|
return hash_combine(
|
||||||
DenseMapInfo<Type *>::getHashValue(key.first),
|
DenseMapInfo<Type>::getHashValue(key.first),
|
||||||
hash_combine_range(key.second.begin(), key.second.end()));
|
hash_combine_range(key.second.begin(), key.second.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
|
static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) {
|
||||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
return false;
|
return false;
|
||||||
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
|
return lhs == KeyTy(rhs->elementType, rhs->getShape());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> {
|
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> {
|
||||||
// Ranked tensors are uniqued based on their element type and shape.
|
// Ranked tensors are uniqued based on their element type and shape.
|
||||||
using KeyTy = std::pair<Type *, ArrayRef<int>>;
|
using KeyTy = std::pair<Type, ArrayRef<int>>;
|
||||||
using DenseMapInfo<RankedTensorType *>::getHashValue;
|
using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue;
|
||||||
using DenseMapInfo<RankedTensorType *>::isEqual;
|
using DenseMapInfo<RankedTensorTypeStorage *>::isEqual;
|
||||||
|
|
||||||
static unsigned getHashValue(KeyTy key) {
|
static unsigned getHashValue(KeyTy key) {
|
||||||
return hash_combine(
|
return hash_combine(
|
||||||
DenseMapInfo<Type *>::getHashValue(key.first),
|
DenseMapInfo<Type>::getHashValue(key.first),
|
||||||
hash_combine_range(key.second.begin(), key.second.end()));
|
hash_combine_range(key.second.begin(), key.second.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
|
static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) {
|
||||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
return false;
|
return false;
|
||||||
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
|
return lhs == KeyTy(rhs->elementType, rhs->getShape());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
|
struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> {
|
||||||
// MemRefs are uniqued based on their element type, shape, affine map
|
// MemRefs are uniqued based on their element type, shape, affine map
|
||||||
// composition, and memory space.
|
// composition, and memory space.
|
||||||
using KeyTy =
|
using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
|
||||||
std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
|
using DenseMapInfo<MemRefTypeStorage *>::getHashValue;
|
||||||
using DenseMapInfo<MemRefType *>::getHashValue;
|
using DenseMapInfo<MemRefTypeStorage *>::isEqual;
|
||||||
using DenseMapInfo<MemRefType *>::isEqual;
|
|
||||||
|
|
||||||
static unsigned getHashValue(KeyTy key) {
|
static unsigned getHashValue(KeyTy key) {
|
||||||
return hash_combine(
|
return hash_combine(
|
||||||
DenseMapInfo<Type *>::getHashValue(std::get<0>(key)),
|
DenseMapInfo<Type>::getHashValue(std::get<0>(key)),
|
||||||
hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()),
|
hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()),
|
||||||
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
|
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
|
||||||
std::get<3>(key));
|
std::get<3>(key));
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) {
|
static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) {
|
||||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
return false;
|
return false;
|
||||||
return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(),
|
return lhs == std::make_tuple(rhs->elementType, rhs->getShape(),
|
||||||
rhs->getAffineMaps(), rhs->getMemorySpace());
|
rhs->getAffineMaps(), rhs->memorySpace);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -221,7 +221,7 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
|
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
|
||||||
using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
|
using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
|
||||||
using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
|
using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
|
||||||
using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
|
using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
|
||||||
|
|
||||||
|
@ -239,7 +239,7 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
|
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
|
||||||
using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
|
using KeyTy = std::pair<VectorOrTensorType, StringRef>;
|
||||||
using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
|
using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
|
||||||
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
|
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
|
||||||
|
|
||||||
|
@ -295,13 +295,14 @@ public:
|
||||||
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
|
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
|
||||||
|
|
||||||
// Uniquing table for 'other' types.
|
// Uniquing table for 'other' types.
|
||||||
OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
|
OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
|
||||||
int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr};
|
int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {
|
||||||
|
nullptr};
|
||||||
|
|
||||||
// Uniquing table for 'float' types.
|
// Uniquing table for 'float' types.
|
||||||
FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
|
FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
|
||||||
int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = {
|
int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] =
|
||||||
nullptr};
|
{nullptr};
|
||||||
|
|
||||||
// Affine map uniquing.
|
// Affine map uniquing.
|
||||||
using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
|
using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
|
||||||
|
@ -324,26 +325,26 @@ public:
|
||||||
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
|
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
|
||||||
|
|
||||||
/// Integer type uniquing.
|
/// Integer type uniquing.
|
||||||
DenseMap<unsigned, IntegerType *> integers;
|
DenseMap<unsigned, IntegerTypeStorage *> integers;
|
||||||
|
|
||||||
/// Function type uniquing.
|
/// Function type uniquing.
|
||||||
using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>;
|
using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>;
|
||||||
FunctionTypeSet functions;
|
FunctionTypeSet functions;
|
||||||
|
|
||||||
/// Vector type uniquing.
|
/// Vector type uniquing.
|
||||||
using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>;
|
using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>;
|
||||||
VectorTypeSet vectors;
|
VectorTypeSet vectors;
|
||||||
|
|
||||||
/// Ranked tensor type uniquing.
|
/// Ranked tensor type uniquing.
|
||||||
using RankedTensorTypeSet =
|
using RankedTensorTypeSet =
|
||||||
DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>;
|
DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>;
|
||||||
RankedTensorTypeSet rankedTensors;
|
RankedTensorTypeSet rankedTensors;
|
||||||
|
|
||||||
/// Unranked tensor type uniquing.
|
/// Unranked tensor type uniquing.
|
||||||
DenseMap<Type *, UnrankedTensorType *> unrankedTensors;
|
DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors;
|
||||||
|
|
||||||
/// MemRef type uniquing.
|
/// MemRef type uniquing.
|
||||||
using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>;
|
using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>;
|
||||||
MemRefTypeSet memrefs;
|
MemRefTypeSet memrefs;
|
||||||
|
|
||||||
// Attribute uniquing.
|
// Attribute uniquing.
|
||||||
|
@ -355,13 +356,12 @@ public:
|
||||||
ArrayAttrSet arrayAttrs;
|
ArrayAttrSet arrayAttrs;
|
||||||
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
|
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
|
||||||
DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
|
DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
|
||||||
DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
|
DenseMap<Type, TypeAttributeStorage *> typeAttrs;
|
||||||
using AttributeListSet =
|
using AttributeListSet =
|
||||||
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
|
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
|
||||||
AttributeListSet attributeLists;
|
AttributeListSet attributeLists;
|
||||||
DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
|
DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
|
||||||
DenseMap<std::pair<VectorOrTensorType *, Attribute>,
|
DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *>
|
||||||
SplatElementsAttributeStorage *>
|
|
||||||
splatElementsAttrs;
|
splatElementsAttrs;
|
||||||
using DenseElementsAttrSet =
|
using DenseElementsAttrSet =
|
||||||
DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
|
DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
|
||||||
|
@ -369,7 +369,7 @@ public:
|
||||||
using OpaqueElementsAttrSet =
|
using OpaqueElementsAttrSet =
|
||||||
DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
|
DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
|
||||||
OpaqueElementsAttrSet opaqueElementsAttrs;
|
OpaqueElementsAttrSet opaqueElementsAttrs;
|
||||||
DenseMap<std::tuple<Type *, Attribute, Attribute>,
|
DenseMap<std::tuple<Type, Attribute, Attribute>,
|
||||||
SparseElementsAttributeStorage *>
|
SparseElementsAttributeStorage *>
|
||||||
sparseElementsAttrs;
|
sparseElementsAttrs;
|
||||||
|
|
||||||
|
@ -556,19 +556,20 @@ FileLineColLoc *FileLineColLoc::get(UniquedFilename filename, unsigned line,
|
||||||
// Type uniquing
|
// Type uniquing
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
|
IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
|
||||||
|
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
auto *&result = impl.integers[width];
|
auto *&result = impl.integers[width];
|
||||||
if (!result) {
|
if (!result) {
|
||||||
result = impl.allocator.Allocate<IntegerType>();
|
result = impl.allocator.Allocate<IntegerTypeStorage>();
|
||||||
new (result) IntegerType(width, context);
|
new (result) IntegerTypeStorage{{Kind::Integer, context}, width};
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatType *FloatType::get(Kind kind, MLIRContext *context) {
|
FloatType FloatType::get(Kind kind, MLIRContext *context) {
|
||||||
assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
|
assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
|
||||||
kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
|
kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
@ -580,16 +581,16 @@ FloatType *FloatType::get(Kind kind, MLIRContext *context) {
|
||||||
return entry;
|
return entry;
|
||||||
|
|
||||||
// On the first use, we allocate them into the bump pointer.
|
// On the first use, we allocate them into the bump pointer.
|
||||||
auto *ptr = impl.allocator.Allocate<FloatType>();
|
auto *ptr = impl.allocator.Allocate<FloatTypeStorage>();
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
new (ptr) FloatType(kind, context);
|
new (ptr) FloatTypeStorage{{kind, context}};
|
||||||
|
|
||||||
// Cache and return it.
|
// Cache and return it.
|
||||||
return entry = ptr;
|
return entry = ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
OtherType *OtherType::get(Kind kind, MLIRContext *context) {
|
OtherType OtherType::get(Kind kind, MLIRContext *context) {
|
||||||
assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE &&
|
assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE &&
|
||||||
"Not an 'other' type kind");
|
"Not an 'other' type kind");
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
@ -600,18 +601,17 @@ OtherType *OtherType::get(Kind kind, MLIRContext *context) {
|
||||||
return entry;
|
return entry;
|
||||||
|
|
||||||
// On the first use, we allocate them into the bump pointer.
|
// On the first use, we allocate them into the bump pointer.
|
||||||
auto *ptr = impl.allocator.Allocate<OtherType>();
|
auto *ptr = impl.allocator.Allocate<OtherTypeStorage>();
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
new (ptr) OtherType(kind, context);
|
new (ptr) OtherTypeStorage{{kind, context}};
|
||||||
|
|
||||||
// Cache and return it.
|
// Cache and return it.
|
||||||
return entry = ptr;
|
return entry = ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
|
FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
|
||||||
ArrayRef<Type *> results,
|
MLIRContext *context) {
|
||||||
MLIRContext *context) {
|
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Look to see if we already have this function type.
|
// Look to see if we already have this function type.
|
||||||
|
@ -623,32 +623,34 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
|
||||||
return *existing.first;
|
return *existing.first;
|
||||||
|
|
||||||
// On the first use, we allocate them into the bump pointer.
|
// On the first use, we allocate them into the bump pointer.
|
||||||
auto *result = impl.allocator.Allocate<FunctionType>();
|
auto *result = impl.allocator.Allocate<FunctionTypeStorage>();
|
||||||
|
|
||||||
// Copy the inputs and results into the bump pointer.
|
// Copy the inputs and results into the bump pointer.
|
||||||
SmallVector<Type *, 16> types;
|
SmallVector<Type, 16> types;
|
||||||
types.reserve(inputs.size() + results.size());
|
types.reserve(inputs.size() + results.size());
|
||||||
types.append(inputs.begin(), inputs.end());
|
types.append(inputs.begin(), inputs.end());
|
||||||
types.append(results.begin(), results.end());
|
types.append(results.begin(), results.end());
|
||||||
auto typesList = impl.copyInto(ArrayRef<Type *>(types));
|
auto typesList = impl.copyInto(ArrayRef<Type>(types));
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
new (result)
|
new (result) FunctionTypeStorage{
|
||||||
FunctionType(typesList.data(), inputs.size(), results.size(), context);
|
{Kind::Function, context, static_cast<unsigned int>(inputs.size())},
|
||||||
|
static_cast<unsigned int>(results.size()),
|
||||||
|
typesList.data()};
|
||||||
|
|
||||||
// Cache and return it.
|
// Cache and return it.
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
|
VectorType VectorType::get(ArrayRef<int> shape, Type elementType) {
|
||||||
assert(!shape.empty() && "vector types must have at least one dimension");
|
assert(!shape.empty() && "vector types must have at least one dimension");
|
||||||
assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
|
assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) &&
|
||||||
"vectors elements must be primitives");
|
"vectors elements must be primitives");
|
||||||
assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
|
assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
|
||||||
return i < 0;
|
return i < 0;
|
||||||
}) && "vector types must have static shape");
|
}) && "vector types must have static shape");
|
||||||
|
|
||||||
auto *context = elementType->getContext();
|
auto *context = elementType.getContext();
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Look to see if we already have this vector type.
|
// Look to see if we already have this vector type.
|
||||||
|
@ -660,21 +662,23 @@ VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
|
||||||
return *existing.first;
|
return *existing.first;
|
||||||
|
|
||||||
// On the first use, we allocate them into the bump pointer.
|
// On the first use, we allocate them into the bump pointer.
|
||||||
auto *result = impl.allocator.Allocate<VectorType>();
|
auto *result = impl.allocator.Allocate<VectorTypeStorage>();
|
||||||
|
|
||||||
// Copy the shape into the bump pointer.
|
// Copy the shape into the bump pointer.
|
||||||
shape = impl.copyInto(shape);
|
shape = impl.copyInto(shape);
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
new (result) VectorType(shape, elementType, context);
|
new (result) VectorTypeStorage{
|
||||||
|
{{Kind::Vector, context, static_cast<unsigned int>(shape.size())},
|
||||||
|
elementType},
|
||||||
|
shape.data()};
|
||||||
|
|
||||||
// Cache and return it.
|
// Cache and return it.
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
|
RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) {
|
||||||
Type *elementType) {
|
auto *context = elementType.getContext();
|
||||||
auto *context = elementType->getContext();
|
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Look to see if we already have this ranked tensor type.
|
// Look to see if we already have this ranked tensor type.
|
||||||
|
@ -686,20 +690,23 @@ RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
|
||||||
return *existing.first;
|
return *existing.first;
|
||||||
|
|
||||||
// On the first use, we allocate them into the bump pointer.
|
// On the first use, we allocate them into the bump pointer.
|
||||||
auto *result = impl.allocator.Allocate<RankedTensorType>();
|
auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>();
|
||||||
|
|
||||||
// Copy the shape into the bump pointer.
|
// Copy the shape into the bump pointer.
|
||||||
shape = impl.copyInto(shape);
|
shape = impl.copyInto(shape);
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
new (result) RankedTensorType(shape, elementType, context);
|
new (result) RankedTensorTypeStorage{
|
||||||
|
{{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())},
|
||||||
|
elementType}},
|
||||||
|
shape.data()};
|
||||||
|
|
||||||
// Cache and return it.
|
// Cache and return it.
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
|
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
|
||||||
auto *context = elementType->getContext();
|
auto *context = elementType.getContext();
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Look to see if we already have this unranked tensor type.
|
// Look to see if we already have this unranked tensor type.
|
||||||
|
@ -710,17 +717,18 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
// On the first use, we allocate them into the bump pointer.
|
// On the first use, we allocate them into the bump pointer.
|
||||||
result = impl.allocator.Allocate<UnrankedTensorType>();
|
result = impl.allocator.Allocate<UnrankedTensorTypeStorage>();
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
new (result) UnrankedTensorType(elementType, context);
|
new (result) UnrankedTensorTypeStorage{
|
||||||
|
{{{Kind::UnrankedTensor, context}, elementType}}};
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
|
MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType,
|
||||||
ArrayRef<AffineMap> affineMapComposition,
|
ArrayRef<AffineMap> affineMapComposition,
|
||||||
unsigned memorySpace) {
|
unsigned memorySpace) {
|
||||||
auto *context = elementType->getContext();
|
auto *context = elementType.getContext();
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Drop the unbounded identity maps from the composition.
|
// Drop the unbounded identity maps from the composition.
|
||||||
|
@ -744,7 +752,7 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
|
||||||
return *existing.first;
|
return *existing.first;
|
||||||
|
|
||||||
// On the first use, we allocate them into the bump pointer.
|
// On the first use, we allocate them into the bump pointer.
|
||||||
auto *result = impl.allocator.Allocate<MemRefType>();
|
auto *result = impl.allocator.Allocate<MemRefTypeStorage>();
|
||||||
|
|
||||||
// Copy the shape into the bump pointer.
|
// Copy the shape into the bump pointer.
|
||||||
shape = impl.copyInto(shape);
|
shape = impl.copyInto(shape);
|
||||||
|
@ -755,8 +763,13 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
|
||||||
impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
|
impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace,
|
new (result) MemRefTypeStorage{
|
||||||
context);
|
{Kind::MemRef, context, static_cast<unsigned int>(shape.size())},
|
||||||
|
elementType,
|
||||||
|
shape.data(),
|
||||||
|
static_cast<unsigned int>(affineMapComposition.size()),
|
||||||
|
affineMapComposition.data(),
|
||||||
|
memorySpace};
|
||||||
// Cache and return it.
|
// Cache and return it.
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
@ -895,7 +908,7 @@ IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
|
TypeAttr TypeAttr::get(Type type, MLIRContext *context) {
|
||||||
auto *&result = context->getImpl().typeAttrs[type];
|
auto *&result = context->getImpl().typeAttrs[type];
|
||||||
if (result)
|
if (result)
|
||||||
return result;
|
return result;
|
||||||
|
@ -1009,9 +1022,9 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
|
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
|
||||||
Attribute elt) {
|
Attribute elt) {
|
||||||
auto &impl = type->getContext()->getImpl();
|
auto &impl = type.getContext()->getImpl();
|
||||||
|
|
||||||
// Look to see if we already have this.
|
// Look to see if we already have this.
|
||||||
auto *&result = impl.splatElementsAttrs[{type, elt}];
|
auto *&result = impl.splatElementsAttrs[{type, elt}];
|
||||||
|
@ -1030,14 +1043,14 @@ SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
|
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
||||||
ArrayRef<char> data) {
|
ArrayRef<char> data) {
|
||||||
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
|
auto bitsRequired = (long)type.getBitWidth() * type.getNumElements();
|
||||||
(void)bitsRequired;
|
(void)bitsRequired;
|
||||||
assert((bitsRequired <= data.size() * 8L) &&
|
assert((bitsRequired <= data.size() * 8L) &&
|
||||||
"Input data bit size should be larger than that type requires");
|
"Input data bit size should be larger than that type requires");
|
||||||
|
|
||||||
auto &impl = type->getContext()->getImpl();
|
auto &impl = type.getContext()->getImpl();
|
||||||
|
|
||||||
// Look to see if this constant is already defined.
|
// Look to see if this constant is already defined.
|
||||||
DenseElementsAttrInfo::KeyTy key({type, data});
|
DenseElementsAttrInfo::KeyTy key({type, data});
|
||||||
|
@ -1048,8 +1061,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
|
||||||
return *existing.first;
|
return *existing.first;
|
||||||
|
|
||||||
// Otherwise, allocate a new one, unique it and return it.
|
// Otherwise, allocate a new one, unique it and return it.
|
||||||
auto *eltType = type->getElementType();
|
auto eltType = type.getElementType();
|
||||||
switch (eltType->getKind()) {
|
switch (eltType.getKind()) {
|
||||||
case Type::Kind::BF16:
|
case Type::Kind::BF16:
|
||||||
case Type::Kind::F16:
|
case Type::Kind::F16:
|
||||||
case Type::Kind::F32:
|
case Type::Kind::F32:
|
||||||
|
@ -1064,7 +1077,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
case Type::Kind::Integer: {
|
case Type::Kind::Integer: {
|
||||||
auto width = ::cast<IntegerType>(eltType)->getWidth();
|
auto width = eltType.cast<IntegerType>().getWidth();
|
||||||
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
|
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
|
||||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||||
std::uninitialized_copy(data.begin(), data.end(), copy);
|
std::uninitialized_copy(data.begin(), data.end(), copy);
|
||||||
|
@ -1080,12 +1093,12 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
|
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
|
||||||
StringRef bytes) {
|
StringRef bytes) {
|
||||||
assert(isValidTensorElementType(type->getElementType()) &&
|
assert(isValidTensorElementType(type.getElementType()) &&
|
||||||
"Input element type should be a valid tensor element type");
|
"Input element type should be a valid tensor element type");
|
||||||
|
|
||||||
auto &impl = type->getContext()->getImpl();
|
auto &impl = type.getContext()->getImpl();
|
||||||
|
|
||||||
// Look to see if this constant is already defined.
|
// Look to see if this constant is already defined.
|
||||||
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
|
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
|
||||||
|
@ -1104,10 +1117,10 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
|
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
|
||||||
DenseIntElementsAttr indices,
|
DenseIntElementsAttr indices,
|
||||||
DenseElementsAttr values) {
|
DenseElementsAttr values) {
|
||||||
auto &impl = type->getContext()->getImpl();
|
auto &impl = type.getContext()->getImpl();
|
||||||
|
|
||||||
// Look to see if we already have this.
|
// Look to see if we already have this.
|
||||||
auto key = std::make_tuple(type, indices, values);
|
auto key = std::make_tuple(type, indices, values);
|
||||||
|
|
|
@ -377,7 +377,7 @@ bool OpTrait::impl::verifyAtLeastNResults(const Operation *op,
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
|
bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
|
||||||
auto *type = op->getResult(0)->getType();
|
auto type = op->getResult(0)->getType();
|
||||||
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
|
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
|
||||||
if (op->getResult(i)->getType() != type)
|
if (op->getResult(i)->getType() != type)
|
||||||
return op->emitOpError(
|
return op->emitOpError(
|
||||||
|
@ -393,19 +393,19 @@ bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
|
||||||
|
|
||||||
/// If this is a vector type, or a tensor type, return the scalar element type
|
/// If this is a vector type, or a tensor type, return the scalar element type
|
||||||
/// that it is built around, otherwise return the type unmodified.
|
/// that it is built around, otherwise return the type unmodified.
|
||||||
static Type *getTensorOrVectorElementType(Type *type) {
|
static Type getTensorOrVectorElementType(Type type) {
|
||||||
if (auto *vec = dyn_cast<VectorType>(type))
|
if (auto vec = type.dyn_cast<VectorType>())
|
||||||
return vec->getElementType();
|
return vec.getElementType();
|
||||||
|
|
||||||
// Look through tensor<vector<...>> to find the underlying element type.
|
// Look through tensor<vector<...>> to find the underlying element type.
|
||||||
if (auto *tensor = dyn_cast<TensorType>(type))
|
if (auto tensor = type.dyn_cast<TensorType>())
|
||||||
return getTensorOrVectorElementType(tensor->getElementType());
|
return getTensorOrVectorElementType(tensor.getElementType());
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
|
bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
|
||||||
for (auto *result : op->getResults()) {
|
for (auto *result : op->getResults()) {
|
||||||
if (!isa<FloatType>(getTensorOrVectorElementType(result->getType())))
|
if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
|
||||||
return op->emitOpError("requires a floating point type");
|
return op->emitOpError("requires a floating point type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -414,7 +414,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
|
||||||
|
|
||||||
bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
|
bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
|
||||||
for (auto *result : op->getResults()) {
|
for (auto *result : op->getResults()) {
|
||||||
if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType())))
|
if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>())
|
||||||
return op->emitOpError("requires an integer type");
|
return op->emitOpError("requires an integer type");
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -436,7 +436,7 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result,
|
||||||
|
|
||||||
bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
|
bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
|
||||||
SmallVector<OpAsmParser::OperandType, 2> ops;
|
SmallVector<OpAsmParser::OperandType, 2> ops;
|
||||||
Type *type;
|
Type type;
|
||||||
return parser->parseOperandList(ops, 2) ||
|
return parser->parseOperandList(ops, 2) ||
|
||||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||||
parser->parseColonType(type) ||
|
parser->parseColonType(type) ||
|
||||||
|
@ -448,7 +448,7 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
|
||||||
*p << op->getName() << ' ' << *op->getOperand(0) << ", "
|
*p << op->getName() << ' ' << *op->getOperand(0) << ", "
|
||||||
<< *op->getOperand(1);
|
<< *op->getOperand(1);
|
||||||
p->printOptionalAttrDict(op->getAttrs());
|
p->printOptionalAttrDict(op->getAttrs());
|
||||||
*p << " : " << *op->getResult(0)->getType();
|
*p << " : " << op->getResult(0)->getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -456,14 +456,14 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void impl::buildCastOp(Builder *builder, OperationState *result,
|
void impl::buildCastOp(Builder *builder, OperationState *result,
|
||||||
SSAValue *source, Type *destType) {
|
SSAValue *source, Type destType) {
|
||||||
result->addOperands(source);
|
result->addOperands(source);
|
||||||
result->addTypes(destType);
|
result->addTypes(destType);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
|
bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
|
||||||
OpAsmParser::OperandType srcInfo;
|
OpAsmParser::OperandType srcInfo;
|
||||||
Type *srcType, *dstType;
|
Type srcType, dstType;
|
||||||
return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
|
return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
|
||||||
parser->resolveOperand(srcInfo, srcType, result->operands) ||
|
parser->resolveOperand(srcInfo, srcType, result->operands) ||
|
||||||
parser->parseKeywordType("to", dstType) ||
|
parser->parseKeywordType("to", dstType) ||
|
||||||
|
@ -472,5 +472,5 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
|
||||||
|
|
||||||
void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
|
void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
|
||||||
*p << op->getName() << ' ' << *op->getOperand(0) << " : "
|
*p << op->getName() << ' ' << *op->getOperand(0) << " : "
|
||||||
<< *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType();
|
<< op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
|
||||||
}
|
}
|
||||||
|
|
|
@ -239,7 +239,7 @@ void Statement::moveBefore(StmtBlock *block,
|
||||||
/// Create a new OperationStmt with the specific fields.
|
/// Create a new OperationStmt with the specific fields.
|
||||||
OperationStmt *OperationStmt::create(Location *location, OperationName name,
|
OperationStmt *OperationStmt::create(Location *location, OperationName name,
|
||||||
ArrayRef<MLValue *> operands,
|
ArrayRef<MLValue *> operands,
|
||||||
ArrayRef<Type *> resultTypes,
|
ArrayRef<Type> resultTypes,
|
||||||
ArrayRef<NamedAttribute> attributes,
|
ArrayRef<NamedAttribute> attributes,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
|
auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
|
||||||
|
@ -288,9 +288,9 @@ MLIRContext *OperationStmt::getContext() const {
|
||||||
// If we have a result or operand type, that is a constant time way to get
|
// If we have a result or operand type, that is a constant time way to get
|
||||||
// to the context.
|
// to the context.
|
||||||
if (getNumResults())
|
if (getNumResults())
|
||||||
return getResult(0)->getType()->getContext();
|
return getResult(0)->getType().getContext();
|
||||||
if (getNumOperands())
|
if (getNumOperands())
|
||||||
return getOperand(0)->getType()->getContext();
|
return getOperand(0)->getType().getContext();
|
||||||
|
|
||||||
// In the very odd case where we have no operands or results, fall back to
|
// In the very odd case where we have no operands or results, fall back to
|
||||||
// doing a find.
|
// doing a find.
|
||||||
|
@ -474,7 +474,7 @@ MLIRContext *IfStmt::getContext() const {
|
||||||
if (operands.empty())
|
if (operands.empty())
|
||||||
return findFunction()->getContext();
|
return findFunction()->getContext();
|
||||||
|
|
||||||
return getOperand(0)->getType()->getContext();
|
return getOperand(0)->getType().getContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -501,7 +501,7 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
|
||||||
operands.push_back(remapOperand(opValue));
|
operands.push_back(remapOperand(opValue));
|
||||||
|
|
||||||
if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
|
if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
|
||||||
SmallVector<Type *, 8> resultTypes;
|
SmallVector<Type, 8> resultTypes;
|
||||||
resultTypes.reserve(opStmt->getNumResults());
|
resultTypes.reserve(opStmt->getNumResults());
|
||||||
for (auto *result : opStmt->getResults())
|
for (auto *result : opStmt->getResults())
|
||||||
resultTypes.push_back(result->getType());
|
resultTypes.push_back(result->getType());
|
||||||
|
|
|
@ -0,0 +1,126 @@
|
||||||
|
//===- TypeDetail.h - MLIR Affine Expr storage details ----------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This holds implementation details of Type.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
#ifndef TYPEDETAIL_H_
|
||||||
|
#define TYPEDETAIL_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
class AffineMap;
|
||||||
|
class MLIRContext;
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
/// Base storage class appearing in a Type.
|
||||||
|
struct alignas(8) TypeStorage {
|
||||||
|
TypeStorage(Type::Kind kind, MLIRContext *context)
|
||||||
|
: context(context), kind(kind), subclassData(0) {}
|
||||||
|
TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData)
|
||||||
|
: context(context), kind(kind), subclassData(subclassData) {}
|
||||||
|
|
||||||
|
unsigned getSubclassData() const { return subclassData; }
|
||||||
|
|
||||||
|
void setSubclassData(unsigned val) {
|
||||||
|
subclassData = val;
|
||||||
|
// Ensure we don't have any accidental truncation.
|
||||||
|
assert(getSubclassData() == val && "Subclass data too large for field");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This refers to the MLIRContext in which this type was uniqued.
|
||||||
|
MLIRContext *const context;
|
||||||
|
|
||||||
|
/// Classification of the subclass, used for type checking.
|
||||||
|
Type::Kind kind : 8;
|
||||||
|
|
||||||
|
/// Space for subclasses to store data.
|
||||||
|
unsigned subclassData : 24;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct IntegerTypeStorage : public TypeStorage {
|
||||||
|
unsigned width;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FloatTypeStorage : public TypeStorage {};
|
||||||
|
|
||||||
|
struct OtherTypeStorage : public TypeStorage {};
|
||||||
|
|
||||||
|
struct FunctionTypeStorage : public TypeStorage {
|
||||||
|
ArrayRef<Type> getInputs() const {
|
||||||
|
return ArrayRef<Type>(inputsAndResults, subclassData);
|
||||||
|
}
|
||||||
|
ArrayRef<Type> getResults() const {
|
||||||
|
return ArrayRef<Type>(inputsAndResults + subclassData, numResults);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned numResults;
|
||||||
|
Type const *inputsAndResults;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct VectorOrTensorTypeStorage : public TypeStorage {
|
||||||
|
Type elementType;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct VectorTypeStorage : public VectorOrTensorTypeStorage {
|
||||||
|
ArrayRef<int> getShape() const {
|
||||||
|
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||||
|
}
|
||||||
|
|
||||||
|
const int *shapeElements;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TensorTypeStorage : public VectorOrTensorTypeStorage {};
|
||||||
|
|
||||||
|
struct RankedTensorTypeStorage : public TensorTypeStorage {
|
||||||
|
ArrayRef<int> getShape() const {
|
||||||
|
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||||
|
}
|
||||||
|
|
||||||
|
const int *shapeElements;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct UnrankedTensorTypeStorage : public TensorTypeStorage {};
|
||||||
|
|
||||||
|
struct MemRefTypeStorage : public TypeStorage {
|
||||||
|
ArrayRef<int> getShape() const {
|
||||||
|
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<AffineMap> getAffineMaps() const {
|
||||||
|
return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The type of each scalar element of the memref.
|
||||||
|
Type elementType;
|
||||||
|
/// An array of integers which stores the shape dimension sizes.
|
||||||
|
const int *shapeElements;
|
||||||
|
/// The number of affine maps in the 'affineMapList' array.
|
||||||
|
const unsigned numAffineMaps;
|
||||||
|
/// List of affine maps in the memref's layout/index map composition.
|
||||||
|
AffineMap const *affineMapList;
|
||||||
|
/// Memory space in which data referenced by memref resides.
|
||||||
|
const unsigned memorySpace;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace mlir
|
||||||
|
#endif // TYPEDETAIL_H_
|
|
@ -16,10 +16,17 @@
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
#include "TypeDetail.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/Support/STLExtras.h"
|
#include "mlir/Support/STLExtras.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
using namespace mlir::detail;
|
||||||
|
|
||||||
|
Type::Kind Type::getKind() const { return type->kind; }
|
||||||
|
|
||||||
|
MLIRContext *Type::getContext() const { return type->context; }
|
||||||
|
|
||||||
unsigned Type::getBitWidth() const {
|
unsigned Type::getBitWidth() const {
|
||||||
switch (getKind()) {
|
switch (getKind()) {
|
||||||
|
@ -32,34 +39,49 @@ unsigned Type::getBitWidth() const {
|
||||||
case Type::Kind::F64:
|
case Type::Kind::F64:
|
||||||
return 64;
|
return 64;
|
||||||
case Type::Kind::Integer:
|
case Type::Kind::Integer:
|
||||||
return cast<IntegerType>(this)->getWidth();
|
return cast<IntegerType>().getWidth();
|
||||||
case Type::Kind::Vector:
|
case Type::Kind::Vector:
|
||||||
case Type::Kind::RankedTensor:
|
case Type::Kind::RankedTensor:
|
||||||
case Type::Kind::UnrankedTensor:
|
case Type::Kind::UnrankedTensor:
|
||||||
return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth();
|
return cast<VectorOrTensorType>().getElementType().getBitWidth();
|
||||||
// TODO: Handle more types.
|
// TODO: Handle more types.
|
||||||
default:
|
default:
|
||||||
llvm_unreachable("unexpected type");
|
llvm_unreachable("unexpected type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerType::IntegerType(unsigned width, MLIRContext *context)
|
unsigned Type::getSubclassData() const { return type->getSubclassData(); }
|
||||||
: Type(Kind::Integer, context), width(width) {
|
void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
|
||||||
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
|
|
||||||
|
IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
|
||||||
|
|
||||||
|
unsigned IntegerType::getWidth() const {
|
||||||
|
return static_cast<ImplType *>(type)->width;
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {}
|
FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
|
||||||
|
|
||||||
OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {}
|
OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {}
|
||||||
|
|
||||||
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
|
FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {}
|
||||||
unsigned numResults, MLIRContext *context)
|
|
||||||
: Type(Kind::Function, context, numInputs), numResults(numResults),
|
|
||||||
inputsAndResults(inputsAndResults) {}
|
|
||||||
|
|
||||||
VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
|
ArrayRef<Type> FunctionType::getInputs() const {
|
||||||
Type *elementType, unsigned subClassData)
|
return static_cast<ImplType *>(type)->getInputs();
|
||||||
: Type(kind, context, subClassData), elementType(elementType) {}
|
}
|
||||||
|
|
||||||
|
unsigned FunctionType::getNumResults() const {
|
||||||
|
return static_cast<ImplType *>(type)->numResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<Type> FunctionType::getResults() const {
|
||||||
|
return static_cast<ImplType *>(type)->getResults();
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {}
|
||||||
|
|
||||||
|
Type VectorOrTensorType::getElementType() const {
|
||||||
|
return static_cast<ImplType *>(type)->elementType;
|
||||||
|
}
|
||||||
|
|
||||||
unsigned VectorOrTensorType::getNumElements() const {
|
unsigned VectorOrTensorType::getNumElements() const {
|
||||||
switch (getKind()) {
|
switch (getKind()) {
|
||||||
|
@ -103,11 +125,11 @@ int VectorOrTensorType::getDimSize(unsigned i) const {
|
||||||
ArrayRef<int> VectorOrTensorType::getShape() const {
|
ArrayRef<int> VectorOrTensorType::getShape() const {
|
||||||
switch (getKind()) {
|
switch (getKind()) {
|
||||||
case Kind::Vector:
|
case Kind::Vector:
|
||||||
return cast<VectorType>(this)->getShape();
|
return cast<VectorType>().getShape();
|
||||||
case Kind::RankedTensor:
|
case Kind::RankedTensor:
|
||||||
return cast<RankedTensorType>(this)->getShape();
|
return cast<RankedTensorType>().getShape();
|
||||||
case Kind::UnrankedTensor:
|
case Kind::UnrankedTensor:
|
||||||
return cast<RankedTensorType>(this)->getShape();
|
return cast<RankedTensorType>().getShape();
|
||||||
default:
|
default:
|
||||||
llvm_unreachable("not a VectorOrTensorType");
|
llvm_unreachable("not a VectorOrTensorType");
|
||||||
}
|
}
|
||||||
|
@ -118,35 +140,38 @@ bool VectorOrTensorType::hasStaticShape() const {
|
||||||
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
|
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
|
VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
|
||||||
MLIRContext *context)
|
|
||||||
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
|
|
||||||
shapeElements(shape.data()) {}
|
|
||||||
|
|
||||||
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
|
ArrayRef<int> VectorType::getShape() const {
|
||||||
: VectorOrTensorType(kind, context, elementType) {
|
return static_cast<ImplType *>(type)->getShape();
|
||||||
assert(isValidTensorElementType(elementType));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
|
||||||
MLIRContext *context)
|
|
||||||
: TensorType(Kind::RankedTensor, elementType, context),
|
RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
|
||||||
shapeElements(shape.data()) {
|
|
||||||
setSubclassData(shape.size());
|
ArrayRef<int> RankedTensorType::getShape() const {
|
||||||
|
return static_cast<ImplType *>(type)->getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
|
UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
|
||||||
: TensorType(Kind::UnrankedTensor, elementType, context) {}
|
|
||||||
|
|
||||||
MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
|
MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {}
|
||||||
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
|
|
||||||
MLIRContext *context)
|
ArrayRef<int> MemRefType::getShape() const {
|
||||||
: Type(Kind::MemRef, context, shape.size()), elementType(elementType),
|
return static_cast<ImplType *>(type)->getShape();
|
||||||
shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
|
}
|
||||||
affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
|
|
||||||
|
Type MemRefType::getElementType() const {
|
||||||
|
return static_cast<ImplType *>(type)->elementType;
|
||||||
|
}
|
||||||
|
|
||||||
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
|
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
|
||||||
return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
|
return static_cast<ImplType *>(type)->getAffineMaps();
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned MemRefType::getMemorySpace() const {
|
||||||
|
return static_cast<ImplType *>(type)->memorySpace;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned MemRefType::getNumDynamicDims() const {
|
unsigned MemRefType::getNumDynamicDims() const {
|
||||||
|
|
|
@ -182,19 +182,19 @@ public:
|
||||||
// as the results of their action.
|
// as the results of their action.
|
||||||
|
|
||||||
// Type parsing.
|
// Type parsing.
|
||||||
VectorType *parseVectorType();
|
VectorType parseVectorType();
|
||||||
ParseResult parseXInDimensionList();
|
ParseResult parseXInDimensionList();
|
||||||
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
|
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
|
||||||
Type *parseTensorType();
|
Type parseTensorType();
|
||||||
Type *parseMemRefType();
|
Type parseMemRefType();
|
||||||
Type *parseFunctionType();
|
Type parseFunctionType();
|
||||||
Type *parseType();
|
Type parseType();
|
||||||
ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
|
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
|
||||||
ParseResult parseTypeList(SmallVectorImpl<Type *> &elements);
|
ParseResult parseTypeList(SmallVectorImpl<Type> &elements);
|
||||||
|
|
||||||
// Attribute parsing.
|
// Attribute parsing.
|
||||||
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||||
FunctionType *type);
|
FunctionType type);
|
||||||
Attribute parseAttribute();
|
Attribute parseAttribute();
|
||||||
|
|
||||||
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
|
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
|
||||||
|
@ -206,9 +206,9 @@ public:
|
||||||
AffineMap parseAffineMapReference();
|
AffineMap parseAffineMapReference();
|
||||||
IntegerSet parseIntegerSetInline();
|
IntegerSet parseIntegerSetInline();
|
||||||
IntegerSet parseIntegerSetReference();
|
IntegerSet parseIntegerSetReference();
|
||||||
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type);
|
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type);
|
||||||
DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector);
|
DenseElementsAttr parseDenseElementsAttr(Type eltType, bool isVector);
|
||||||
VectorOrTensorType *parseVectorOrTensorType();
|
VectorOrTensorType parseVectorOrTensorType();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// The Parser is subclassed and reinstantiated. Do not add additional
|
// The Parser is subclassed and reinstantiated. Do not add additional
|
||||||
|
@ -299,7 +299,7 @@ ParseResult Parser::parseCommaSeparatedListUntil(
|
||||||
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
|
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
|
||||||
/// other-type ::= `index` | `tf_control`
|
/// other-type ::= `index` | `tf_control`
|
||||||
///
|
///
|
||||||
Type *Parser::parseType() {
|
Type Parser::parseType() {
|
||||||
switch (getToken().getKind()) {
|
switch (getToken().getKind()) {
|
||||||
default:
|
default:
|
||||||
return (emitError("expected type"), nullptr);
|
return (emitError("expected type"), nullptr);
|
||||||
|
@ -368,7 +368,7 @@ Type *Parser::parseType() {
|
||||||
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
|
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
|
||||||
/// const-dimension-list ::= (integer-literal `x`)+
|
/// const-dimension-list ::= (integer-literal `x`)+
|
||||||
///
|
///
|
||||||
VectorType *Parser::parseVectorType() {
|
VectorType Parser::parseVectorType() {
|
||||||
consumeToken(Token::kw_vector);
|
consumeToken(Token::kw_vector);
|
||||||
|
|
||||||
if (parseToken(Token::less, "expected '<' in vector type"))
|
if (parseToken(Token::less, "expected '<' in vector type"))
|
||||||
|
@ -402,11 +402,11 @@ VectorType *Parser::parseVectorType() {
|
||||||
|
|
||||||
// Parse the element type.
|
// Parse the element type.
|
||||||
auto typeLoc = getToken().getLoc();
|
auto typeLoc = getToken().getLoc();
|
||||||
auto *elementType = parseType();
|
auto elementType = parseType();
|
||||||
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
|
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
|
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
|
||||||
return (emitError(typeLoc, "invalid vector element type"), nullptr);
|
return (emitError(typeLoc, "invalid vector element type"), nullptr);
|
||||||
|
|
||||||
return VectorType::get(dimensions, elementType);
|
return VectorType::get(dimensions, elementType);
|
||||||
|
@ -461,7 +461,7 @@ ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
|
||||||
/// tensor-type ::= `tensor` `<` dimension-list element-type `>`
|
/// tensor-type ::= `tensor` `<` dimension-list element-type `>`
|
||||||
/// dimension-list ::= dimension-list-ranked | `*x`
|
/// dimension-list ::= dimension-list-ranked | `*x`
|
||||||
///
|
///
|
||||||
Type *Parser::parseTensorType() {
|
Type Parser::parseTensorType() {
|
||||||
consumeToken(Token::kw_tensor);
|
consumeToken(Token::kw_tensor);
|
||||||
|
|
||||||
if (parseToken(Token::less, "expected '<' in tensor type"))
|
if (parseToken(Token::less, "expected '<' in tensor type"))
|
||||||
|
@ -485,7 +485,7 @@ Type *Parser::parseTensorType() {
|
||||||
|
|
||||||
// Parse the element type.
|
// Parse the element type.
|
||||||
auto typeLoc = getToken().getLoc();
|
auto typeLoc = getToken().getLoc();
|
||||||
auto *elementType = parseType();
|
auto elementType = parseType();
|
||||||
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
|
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
|
@ -505,7 +505,7 @@ Type *Parser::parseTensorType() {
|
||||||
/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
|
/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
|
||||||
/// memory-space ::= integer-literal /* | TODO: address-space-id */
|
/// memory-space ::= integer-literal /* | TODO: address-space-id */
|
||||||
///
|
///
|
||||||
Type *Parser::parseMemRefType() {
|
Type Parser::parseMemRefType() {
|
||||||
consumeToken(Token::kw_memref);
|
consumeToken(Token::kw_memref);
|
||||||
|
|
||||||
if (parseToken(Token::less, "expected '<' in memref type"))
|
if (parseToken(Token::less, "expected '<' in memref type"))
|
||||||
|
@ -517,12 +517,12 @@ Type *Parser::parseMemRefType() {
|
||||||
|
|
||||||
// Parse the element type.
|
// Parse the element type.
|
||||||
auto typeLoc = getToken().getLoc();
|
auto typeLoc = getToken().getLoc();
|
||||||
auto *elementType = parseType();
|
auto elementType = parseType();
|
||||||
if (!elementType)
|
if (!elementType)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
|
if (!elementType.isa<IntegerType>() && !elementType.isa<FloatType>() &&
|
||||||
!isa<VectorType>(elementType))
|
!elementType.isa<VectorType>())
|
||||||
return (emitError(typeLoc, "invalid memref element type"), nullptr);
|
return (emitError(typeLoc, "invalid memref element type"), nullptr);
|
||||||
|
|
||||||
// Parse semi-affine-map-composition.
|
// Parse semi-affine-map-composition.
|
||||||
|
@ -581,10 +581,10 @@ Type *Parser::parseMemRefType() {
|
||||||
///
|
///
|
||||||
/// function-type ::= type-list-parens `->` type-list
|
/// function-type ::= type-list-parens `->` type-list
|
||||||
///
|
///
|
||||||
Type *Parser::parseFunctionType() {
|
Type Parser::parseFunctionType() {
|
||||||
assert(getToken().is(Token::l_paren));
|
assert(getToken().is(Token::l_paren));
|
||||||
|
|
||||||
SmallVector<Type *, 4> arguments, results;
|
SmallVector<Type, 4> arguments, results;
|
||||||
if (parseTypeList(arguments) ||
|
if (parseTypeList(arguments) ||
|
||||||
parseToken(Token::arrow, "expected '->' in function type") ||
|
parseToken(Token::arrow, "expected '->' in function type") ||
|
||||||
parseTypeList(results))
|
parseTypeList(results))
|
||||||
|
@ -598,7 +598,7 @@ Type *Parser::parseFunctionType() {
|
||||||
///
|
///
|
||||||
/// type-list-no-parens ::= type (`,` type)*
|
/// type-list-no-parens ::= type (`,` type)*
|
||||||
///
|
///
|
||||||
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
|
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
|
||||||
auto parseElt = [&]() -> ParseResult {
|
auto parseElt = [&]() -> ParseResult {
|
||||||
auto elt = parseType();
|
auto elt = parseType();
|
||||||
elements.push_back(elt);
|
elements.push_back(elt);
|
||||||
|
@ -615,7 +615,7 @@ ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
|
||||||
/// type-list-parens ::= `(` `)`
|
/// type-list-parens ::= `(` `)`
|
||||||
/// | `(` type-list-no-parens `)`
|
/// | `(` type-list-no-parens `)`
|
||||||
///
|
///
|
||||||
ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
|
ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
|
||||||
auto parseElt = [&]() -> ParseResult {
|
auto parseElt = [&]() -> ParseResult {
|
||||||
auto elt = parseType();
|
auto elt = parseType();
|
||||||
elements.push_back(elt);
|
elements.push_back(elt);
|
||||||
|
@ -639,8 +639,8 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
|
||||||
namespace {
|
namespace {
|
||||||
class TensorLiteralParser {
|
class TensorLiteralParser {
|
||||||
public:
|
public:
|
||||||
TensorLiteralParser(Parser &p, Type *eltTy)
|
TensorLiteralParser(Parser &p, Type eltTy)
|
||||||
: p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {}
|
: p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {}
|
||||||
|
|
||||||
ParseResult parse() { return parseList(shape); }
|
ParseResult parse() { return parseList(shape); }
|
||||||
|
|
||||||
|
@ -676,7 +676,7 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
Parser &p;
|
Parser &p;
|
||||||
Type *eltTy;
|
Type eltTy;
|
||||||
size_t currBitPos;
|
size_t currBitPos;
|
||||||
size_t bitsWidth;
|
size_t bitsWidth;
|
||||||
SmallVector<int, 4> shape;
|
SmallVector<int, 4> shape;
|
||||||
|
@ -698,7 +698,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
||||||
if (!result)
|
if (!result)
|
||||||
return p.emitError("expected tensor element");
|
return p.emitError("expected tensor element");
|
||||||
// check result matches the element type.
|
// check result matches the element type.
|
||||||
switch (eltTy->getKind()) {
|
switch (eltTy.getKind()) {
|
||||||
case Type::Kind::BF16:
|
case Type::Kind::BF16:
|
||||||
case Type::Kind::F16:
|
case Type::Kind::F16:
|
||||||
case Type::Kind::F32:
|
case Type::Kind::F32:
|
||||||
|
@ -779,7 +779,7 @@ ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl<int> &dims) {
|
||||||
/// synthesizing a forward reference) or emit an error and return null on
|
/// synthesizing a forward reference) or emit an error and return null on
|
||||||
/// failure.
|
/// failure.
|
||||||
Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||||
FunctionType *type) {
|
FunctionType type) {
|
||||||
Identifier name = builder.getIdentifier(nameStr.drop_front());
|
Identifier name = builder.getIdentifier(nameStr.drop_front());
|
||||||
|
|
||||||
// See if the function has already been defined in the module.
|
// See if the function has already been defined in the module.
|
||||||
|
@ -902,10 +902,10 @@ Attribute Parser::parseAttribute() {
|
||||||
if (parseToken(Token::colon, "expected ':' and function type"))
|
if (parseToken(Token::colon, "expected ':' and function type"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto typeLoc = getToken().getLoc();
|
auto typeLoc = getToken().getLoc();
|
||||||
Type *type = parseType();
|
Type type = parseType();
|
||||||
if (!type)
|
if (!type)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto *fnType = dyn_cast<FunctionType>(type);
|
auto fnType = type.dyn_cast<FunctionType>();
|
||||||
if (!fnType)
|
if (!fnType)
|
||||||
return (emitError(typeLoc, "expected function type"), nullptr);
|
return (emitError(typeLoc, "expected function type"), nullptr);
|
||||||
|
|
||||||
|
@ -916,7 +916,7 @@ Attribute Parser::parseAttribute() {
|
||||||
consumeToken(Token::kw_opaque);
|
consumeToken(Token::kw_opaque);
|
||||||
if (parseToken(Token::less, "expected '<' after 'opaque'"))
|
if (parseToken(Token::less, "expected '<' after 'opaque'"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto *type = parseVectorOrTensorType();
|
auto type = parseVectorOrTensorType();
|
||||||
if (!type)
|
if (!type)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto val = getToken().getStringValue();
|
auto val = getToken().getStringValue();
|
||||||
|
@ -937,7 +937,7 @@ Attribute Parser::parseAttribute() {
|
||||||
if (parseToken(Token::less, "expected '<' after 'splat'"))
|
if (parseToken(Token::less, "expected '<' after 'splat'"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto *type = parseVectorOrTensorType();
|
auto type = parseVectorOrTensorType();
|
||||||
if (!type)
|
if (!type)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
switch (getToken().getKind()) {
|
switch (getToken().getKind()) {
|
||||||
|
@ -959,7 +959,7 @@ Attribute Parser::parseAttribute() {
|
||||||
if (parseToken(Token::less, "expected '<' after 'dense'"))
|
if (parseToken(Token::less, "expected '<' after 'dense'"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto *type = parseVectorOrTensorType();
|
auto type = parseVectorOrTensorType();
|
||||||
if (!type)
|
if (!type)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
|
@ -981,41 +981,41 @@ Attribute Parser::parseAttribute() {
|
||||||
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
|
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto *type = parseVectorOrTensorType();
|
auto type = parseVectorOrTensorType();
|
||||||
if (!type)
|
if (!type)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
switch (getToken().getKind()) {
|
switch (getToken().getKind()) {
|
||||||
case Token::l_square: {
|
case Token::l_square: {
|
||||||
/// Parse indices
|
/// Parse indices
|
||||||
auto *indicesEltType = builder.getIntegerType(32);
|
auto indicesEltType = builder.getIntegerType(32);
|
||||||
auto indices =
|
auto indices =
|
||||||
parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
|
parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
|
||||||
|
|
||||||
if (parseToken(Token::comma, "expected ','"))
|
if (parseToken(Token::comma, "expected ','"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
/// Parse values.
|
/// Parse values.
|
||||||
auto *valuesEltType = type->getElementType();
|
auto valuesEltType = type.getElementType();
|
||||||
auto values =
|
auto values =
|
||||||
parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
|
parseDenseElementsAttr(valuesEltType, type.isa<VectorType>());
|
||||||
|
|
||||||
/// Sanity check.
|
/// Sanity check.
|
||||||
auto *indicesType = indices.getType();
|
auto indicesType = indices.getType();
|
||||||
auto *valuesType = values.getType();
|
auto valuesType = values.getType();
|
||||||
auto sameShape = (indicesType->getRank() == 1) ||
|
auto sameShape = (indicesType.getRank() == 1) ||
|
||||||
(type->getRank() == indicesType->getDimSize(1));
|
(type.getRank() == indicesType.getDimSize(1));
|
||||||
auto sameElementNum =
|
auto sameElementNum =
|
||||||
indicesType->getDimSize(0) == valuesType->getDimSize(0);
|
indicesType.getDimSize(0) == valuesType.getDimSize(0);
|
||||||
if (!sameShape || !sameElementNum) {
|
if (!sameShape || !sameElementNum) {
|
||||||
std::string str;
|
std::string str;
|
||||||
llvm::raw_string_ostream s(str);
|
llvm::raw_string_ostream s(str);
|
||||||
s << "expected shape ([";
|
s << "expected shape ([";
|
||||||
interleaveComma(type->getShape(), s);
|
interleaveComma(type.getShape(), s);
|
||||||
s << "]); inferred shape of indices literal ([";
|
s << "]); inferred shape of indices literal ([";
|
||||||
interleaveComma(indicesType->getShape(), s);
|
interleaveComma(indicesType.getShape(), s);
|
||||||
s << "]); inferred shape of values literal ([";
|
s << "]); inferred shape of values literal ([";
|
||||||
interleaveComma(valuesType->getShape(), s);
|
interleaveComma(valuesType.getShape(), s);
|
||||||
s << "])";
|
s << "])";
|
||||||
return (emitError(s.str()), nullptr);
|
return (emitError(s.str()), nullptr);
|
||||||
}
|
}
|
||||||
|
@ -1035,7 +1035,7 @@ Attribute Parser::parseAttribute() {
|
||||||
nullptr);
|
nullptr);
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
if (Type *type = parseType())
|
if (Type type = parseType())
|
||||||
return builder.getTypeAttr(type);
|
return builder.getTypeAttr(type);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -1051,12 +1051,12 @@ Attribute Parser::parseAttribute() {
|
||||||
///
|
///
|
||||||
/// This method returns a constructed dense elements attribute with the shape
|
/// This method returns a constructed dense elements attribute with the shape
|
||||||
/// from the parsing result.
|
/// from the parsing result.
|
||||||
DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
|
DenseElementsAttr Parser::parseDenseElementsAttr(Type eltType, bool isVector) {
|
||||||
TensorLiteralParser literalParser(*this, eltType);
|
TensorLiteralParser literalParser(*this, eltType);
|
||||||
if (literalParser.parse())
|
if (literalParser.parse())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
VectorOrTensorType *type;
|
VectorOrTensorType type;
|
||||||
if (isVector) {
|
if (isVector) {
|
||||||
type = builder.getVectorType(literalParser.getShape(), eltType);
|
type = builder.getVectorType(literalParser.getShape(), eltType);
|
||||||
} else {
|
} else {
|
||||||
|
@ -1076,18 +1076,18 @@ DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
|
||||||
/// This method compares the shapes from the parsing result and that from the
|
/// This method compares the shapes from the parsing result and that from the
|
||||||
/// input argument. It returns a constructed dense elements attribute if both
|
/// input argument. It returns a constructed dense elements attribute if both
|
||||||
/// match.
|
/// match.
|
||||||
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
|
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
|
||||||
auto *eltTy = type->getElementType();
|
auto eltTy = type.getElementType();
|
||||||
TensorLiteralParser literalParser(*this, eltTy);
|
TensorLiteralParser literalParser(*this, eltTy);
|
||||||
if (literalParser.parse())
|
if (literalParser.parse())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
if (literalParser.getShape() != type->getShape()) {
|
if (literalParser.getShape() != type.getShape()) {
|
||||||
std::string str;
|
std::string str;
|
||||||
llvm::raw_string_ostream s(str);
|
llvm::raw_string_ostream s(str);
|
||||||
s << "inferred shape of elements literal ([";
|
s << "inferred shape of elements literal ([";
|
||||||
interleaveComma(literalParser.getShape(), s);
|
interleaveComma(literalParser.getShape(), s);
|
||||||
s << "]) does not match type ([";
|
s << "]) does not match type ([";
|
||||||
interleaveComma(type->getShape(), s);
|
interleaveComma(type.getShape(), s);
|
||||||
s << "])";
|
s << "])";
|
||||||
return (emitError(s.str()), nullptr);
|
return (emitError(s.str()), nullptr);
|
||||||
}
|
}
|
||||||
|
@ -1100,8 +1100,8 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
|
||||||
/// vector-or-tensor-type ::= vector-type | tensor-type
|
/// vector-or-tensor-type ::= vector-type | tensor-type
|
||||||
///
|
///
|
||||||
/// This method also checks the type has static shape and ranked.
|
/// This method also checks the type has static shape and ranked.
|
||||||
VectorOrTensorType *Parser::parseVectorOrTensorType() {
|
VectorOrTensorType Parser::parseVectorOrTensorType() {
|
||||||
auto *type = dyn_cast<VectorOrTensorType>(parseType());
|
auto type = parseType().dyn_cast<VectorOrTensorType>();
|
||||||
if (!type) {
|
if (!type) {
|
||||||
return (emitError("expected elements literal has a tensor or vector type"),
|
return (emitError("expected elements literal has a tensor or vector type"),
|
||||||
nullptr);
|
nullptr);
|
||||||
|
@ -1110,7 +1110,7 @@ VectorOrTensorType *Parser::parseVectorOrTensorType() {
|
||||||
if (parseToken(Token::comma, "expected ','"))
|
if (parseToken(Token::comma, "expected ','"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
if (!type->hasStaticShape() || type->getRank() == -1) {
|
if (!type.hasStaticShape() || type.getRank() == -1) {
|
||||||
return (emitError("tensor literals must be ranked and have static shape"),
|
return (emitError("tensor literals must be ranked and have static shape"),
|
||||||
nullptr);
|
nullptr);
|
||||||
}
|
}
|
||||||
|
@ -1834,7 +1834,7 @@ public:
|
||||||
|
|
||||||
/// Given a reference to an SSA value and its type, return a reference. This
|
/// Given a reference to an SSA value and its type, return a reference. This
|
||||||
/// returns null on failure.
|
/// returns null on failure.
|
||||||
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type);
|
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type);
|
||||||
|
|
||||||
/// Register a definition of a value with the symbol table.
|
/// Register a definition of a value with the symbol table.
|
||||||
ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
|
ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
|
||||||
|
@ -1845,11 +1845,11 @@ public:
|
||||||
|
|
||||||
template <typename ResultType>
|
template <typename ResultType>
|
||||||
ResultType parseSSADefOrUseAndType(
|
ResultType parseSSADefOrUseAndType(
|
||||||
const std::function<ResultType(SSAUseInfo, Type *)> &action);
|
const std::function<ResultType(SSAUseInfo, Type)> &action);
|
||||||
|
|
||||||
SSAValue *parseSSAUseAndType() {
|
SSAValue *parseSSAUseAndType() {
|
||||||
return parseSSADefOrUseAndType<SSAValue *>(
|
return parseSSADefOrUseAndType<SSAValue *>(
|
||||||
[&](SSAUseInfo useInfo, Type *type) -> SSAValue * {
|
[&](SSAUseInfo useInfo, Type type) -> SSAValue * {
|
||||||
return resolveSSAUse(useInfo, type);
|
return resolveSSAUse(useInfo, type);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -1880,7 +1880,7 @@ private:
|
||||||
/// their first reference, to allow checking for use of undefined values.
|
/// their first reference, to allow checking for use of undefined values.
|
||||||
DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
|
DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
|
||||||
|
|
||||||
SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type *type);
|
SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type);
|
||||||
|
|
||||||
/// Return true if this is a forward reference.
|
/// Return true if this is a forward reference.
|
||||||
bool isForwardReferencePlaceholder(SSAValue *value) {
|
bool isForwardReferencePlaceholder(SSAValue *value) {
|
||||||
|
@ -1891,7 +1891,7 @@ private:
|
||||||
|
|
||||||
/// Create and remember a new placeholder for a forward reference.
|
/// Create and remember a new placeholder for a forward reference.
|
||||||
SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
|
SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
|
||||||
Type *type) {
|
Type type) {
|
||||||
// Forward references are always created as instructions, even in ML
|
// Forward references are always created as instructions, even in ML
|
||||||
// functions, because we just need something with a def/use chain.
|
// functions, because we just need something with a def/use chain.
|
||||||
//
|
//
|
||||||
|
@ -1908,7 +1908,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
|
||||||
|
|
||||||
/// Given an unbound reference to an SSA value and its type, return the value
|
/// Given an unbound reference to an SSA value and its type, return the value
|
||||||
/// it specifies. This returns null on failure.
|
/// it specifies. This returns null on failure.
|
||||||
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) {
|
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
|
||||||
auto &entries = values[useInfo.name];
|
auto &entries = values[useInfo.name];
|
||||||
|
|
||||||
// If we have already seen a value of this name, return it.
|
// If we have already seen a value of this name, return it.
|
||||||
|
@ -2057,14 +2057,14 @@ FunctionParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
|
||||||
/// ssa-use-and-type ::= ssa-use `:` type
|
/// ssa-use-and-type ::= ssa-use `:` type
|
||||||
template <typename ResultType>
|
template <typename ResultType>
|
||||||
ResultType FunctionParser::parseSSADefOrUseAndType(
|
ResultType FunctionParser::parseSSADefOrUseAndType(
|
||||||
const std::function<ResultType(SSAUseInfo, Type *)> &action) {
|
const std::function<ResultType(SSAUseInfo, Type)> &action) {
|
||||||
|
|
||||||
SSAUseInfo useInfo;
|
SSAUseInfo useInfo;
|
||||||
if (parseSSAUse(useInfo) ||
|
if (parseSSAUse(useInfo) ||
|
||||||
parseToken(Token::colon, "expected ':' and type for SSA operand"))
|
parseToken(Token::colon, "expected ':' and type for SSA operand"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto *type = parseType();
|
auto type = parseType();
|
||||||
if (!type)
|
if (!type)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
|
@ -2101,7 +2101,7 @@ ParseResult FunctionParser::parseOptionalSSAUseAndTypeList(
|
||||||
if (valueIDs.empty())
|
if (valueIDs.empty())
|
||||||
return ParseSuccess;
|
return ParseSuccess;
|
||||||
|
|
||||||
SmallVector<Type *, 4> types;
|
SmallVector<Type, 4> types;
|
||||||
if (parseToken(Token::colon, "expected ':' in operand list") ||
|
if (parseToken(Token::colon, "expected ':' in operand list") ||
|
||||||
parseTypeListNoParens(types))
|
parseTypeListNoParens(types))
|
||||||
return ParseFailure;
|
return ParseFailure;
|
||||||
|
@ -2209,14 +2209,14 @@ Operation *FunctionParser::parseVerboseOperation(
|
||||||
auto type = parseType();
|
auto type = parseType();
|
||||||
if (!type)
|
if (!type)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto fnType = dyn_cast<FunctionType>(type);
|
auto fnType = type.dyn_cast<FunctionType>();
|
||||||
if (!fnType)
|
if (!fnType)
|
||||||
return (emitError(typeLoc, "expected function type"), nullptr);
|
return (emitError(typeLoc, "expected function type"), nullptr);
|
||||||
|
|
||||||
result.addTypes(fnType->getResults());
|
result.addTypes(fnType.getResults());
|
||||||
|
|
||||||
// Check that we have the right number of types for the operands.
|
// Check that we have the right number of types for the operands.
|
||||||
auto operandTypes = fnType->getInputs();
|
auto operandTypes = fnType.getInputs();
|
||||||
if (operandTypes.size() != operandInfos.size()) {
|
if (operandTypes.size() != operandInfos.size()) {
|
||||||
auto plural = "s"[operandInfos.size() == 1];
|
auto plural = "s"[operandInfos.size() == 1];
|
||||||
return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
|
return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
|
||||||
|
@ -2253,17 +2253,17 @@ public:
|
||||||
return parser.parseToken(Token::comma, "expected ','");
|
return parser.parseToken(Token::comma, "expected ','");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool parseColonType(Type *&result) override {
|
bool parseColonType(Type &result) override {
|
||||||
return parser.parseToken(Token::colon, "expected ':'") ||
|
return parser.parseToken(Token::colon, "expected ':'") ||
|
||||||
!(result = parser.parseType());
|
!(result = parser.parseType());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool parseColonTypeList(SmallVectorImpl<Type *> &result) override {
|
bool parseColonTypeList(SmallVectorImpl<Type> &result) override {
|
||||||
if (parser.parseToken(Token::colon, "expected ':'"))
|
if (parser.parseToken(Token::colon, "expected ':'"))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
if (auto *type = parser.parseType())
|
if (auto type = parser.parseType())
|
||||||
result.push_back(type);
|
result.push_back(type);
|
||||||
else
|
else
|
||||||
return true;
|
return true;
|
||||||
|
@ -2273,7 +2273,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse a keyword followed by a type.
|
/// Parse a keyword followed by a type.
|
||||||
bool parseKeywordType(const char *keyword, Type *&result) override {
|
bool parseKeywordType(const char *keyword, Type &result) override {
|
||||||
if (parser.getTokenSpelling() != keyword)
|
if (parser.getTokenSpelling() != keyword)
|
||||||
return parser.emitError("expected '" + Twine(keyword) + "'");
|
return parser.emitError("expected '" + Twine(keyword) + "'");
|
||||||
parser.consumeToken();
|
parser.consumeToken();
|
||||||
|
@ -2396,7 +2396,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve a parse function name and a type into a function reference.
|
/// Resolve a parse function name and a type into a function reference.
|
||||||
virtual bool resolveFunctionName(StringRef name, FunctionType *type,
|
virtual bool resolveFunctionName(StringRef name, FunctionType type,
|
||||||
llvm::SMLoc loc, Function *&result) {
|
llvm::SMLoc loc, Function *&result) {
|
||||||
result = parser.resolveFunctionReference(name, loc, type);
|
result = parser.resolveFunctionReference(name, loc, type);
|
||||||
return result == nullptr;
|
return result == nullptr;
|
||||||
|
@ -2410,7 +2410,7 @@ public:
|
||||||
|
|
||||||
llvm::SMLoc getNameLoc() const override { return nameLoc; }
|
llvm::SMLoc getNameLoc() const override { return nameLoc; }
|
||||||
|
|
||||||
bool resolveOperand(const OperandType &operand, Type *type,
|
bool resolveOperand(const OperandType &operand, Type type,
|
||||||
SmallVectorImpl<SSAValue *> &result) override {
|
SmallVectorImpl<SSAValue *> &result) override {
|
||||||
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
|
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
|
||||||
operand.location};
|
operand.location};
|
||||||
|
@ -2559,11 +2559,11 @@ ParseResult CFGFunctionParser::parseOptionalBasicBlockArgList(
|
||||||
return ParseSuccess;
|
return ParseSuccess;
|
||||||
|
|
||||||
return parseCommaSeparatedList([&]() -> ParseResult {
|
return parseCommaSeparatedList([&]() -> ParseResult {
|
||||||
auto type = parseSSADefOrUseAndType<Type *>(
|
auto type = parseSSADefOrUseAndType<Type>(
|
||||||
[&](SSAUseInfo useInfo, Type *type) -> Type * {
|
[&](SSAUseInfo useInfo, Type type) -> Type {
|
||||||
BBArgument *arg = owner->addArgument(type);
|
BBArgument *arg = owner->addArgument(type);
|
||||||
if (addDefinition(useInfo, arg))
|
if (addDefinition(useInfo, arg))
|
||||||
return nullptr;
|
return {};
|
||||||
return type;
|
return type;
|
||||||
});
|
});
|
||||||
return type ? ParseSuccess : ParseFailure;
|
return type ? ParseSuccess : ParseFailure;
|
||||||
|
@ -2908,7 +2908,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
|
||||||
" symbol count must match");
|
" symbol count must match");
|
||||||
|
|
||||||
// Resolve SSA uses.
|
// Resolve SSA uses.
|
||||||
Type *indexType = builder.getIndexType();
|
Type indexType = builder.getIndexType();
|
||||||
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
|
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
|
||||||
SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
|
SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
|
||||||
if (!sval)
|
if (!sval)
|
||||||
|
@ -3187,9 +3187,9 @@ private:
|
||||||
ParseResult parseAffineStructureDef();
|
ParseResult parseAffineStructureDef();
|
||||||
|
|
||||||
// Functions.
|
// Functions.
|
||||||
ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
|
ParseResult parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
|
||||||
SmallVectorImpl<StringRef> &argNames);
|
SmallVectorImpl<StringRef> &argNames);
|
||||||
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
|
ParseResult parseFunctionSignature(StringRef &name, FunctionType &type,
|
||||||
SmallVectorImpl<StringRef> *argNames);
|
SmallVectorImpl<StringRef> *argNames);
|
||||||
ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs);
|
ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs);
|
||||||
ParseResult parseExtFunc();
|
ParseResult parseExtFunc();
|
||||||
|
@ -3248,7 +3248,7 @@ ParseResult ModuleParser::parseAffineStructureDef() {
|
||||||
/// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
|
/// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
|
||||||
///
|
///
|
||||||
ParseResult
|
ParseResult
|
||||||
ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
|
ModuleParser::parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
|
||||||
SmallVectorImpl<StringRef> &argNames) {
|
SmallVectorImpl<StringRef> &argNames) {
|
||||||
consumeToken(Token::l_paren);
|
consumeToken(Token::l_paren);
|
||||||
|
|
||||||
|
@ -3284,7 +3284,7 @@ ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
|
||||||
/// type-list)?
|
/// type-list)?
|
||||||
///
|
///
|
||||||
ParseResult
|
ParseResult
|
||||||
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
|
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
|
||||||
SmallVectorImpl<StringRef> *argNames) {
|
SmallVectorImpl<StringRef> *argNames) {
|
||||||
if (getToken().isNot(Token::at_identifier))
|
if (getToken().isNot(Token::at_identifier))
|
||||||
return emitError("expected a function identifier like '@foo'");
|
return emitError("expected a function identifier like '@foo'");
|
||||||
|
@ -3295,7 +3295,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
|
||||||
if (getToken().isNot(Token::l_paren))
|
if (getToken().isNot(Token::l_paren))
|
||||||
return emitError("expected '(' in function signature");
|
return emitError("expected '(' in function signature");
|
||||||
|
|
||||||
SmallVector<Type *, 4> argTypes;
|
SmallVector<Type, 4> argTypes;
|
||||||
ParseResult parseResult;
|
ParseResult parseResult;
|
||||||
|
|
||||||
if (argNames)
|
if (argNames)
|
||||||
|
@ -3307,7 +3307,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
|
||||||
return ParseFailure;
|
return ParseFailure;
|
||||||
|
|
||||||
// Parse the return type if present.
|
// Parse the return type if present.
|
||||||
SmallVector<Type *, 4> results;
|
SmallVector<Type, 4> results;
|
||||||
if (consumeIf(Token::arrow)) {
|
if (consumeIf(Token::arrow)) {
|
||||||
if (parseTypeList(results))
|
if (parseTypeList(results))
|
||||||
return ParseFailure;
|
return ParseFailure;
|
||||||
|
@ -3340,7 +3340,7 @@ ParseResult ModuleParser::parseExtFunc() {
|
||||||
auto loc = getToken().getLoc();
|
auto loc = getToken().getLoc();
|
||||||
|
|
||||||
StringRef name;
|
StringRef name;
|
||||||
FunctionType *type = nullptr;
|
FunctionType type;
|
||||||
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
|
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
|
||||||
return ParseFailure;
|
return ParseFailure;
|
||||||
|
|
||||||
|
@ -3372,7 +3372,7 @@ ParseResult ModuleParser::parseCFGFunc() {
|
||||||
auto loc = getToken().getLoc();
|
auto loc = getToken().getLoc();
|
||||||
|
|
||||||
StringRef name;
|
StringRef name;
|
||||||
FunctionType *type = nullptr;
|
FunctionType type;
|
||||||
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
|
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
|
||||||
return ParseFailure;
|
return ParseFailure;
|
||||||
|
|
||||||
|
@ -3405,7 +3405,7 @@ ParseResult ModuleParser::parseMLFunc() {
|
||||||
consumeToken(Token::kw_mlfunc);
|
consumeToken(Token::kw_mlfunc);
|
||||||
|
|
||||||
StringRef name;
|
StringRef name;
|
||||||
FunctionType *type = nullptr;
|
FunctionType type;
|
||||||
SmallVector<StringRef, 4> argNames;
|
SmallVector<StringRef, 4> argNames;
|
||||||
|
|
||||||
auto loc = getToken().getLoc();
|
auto loc = getToken().getLoc();
|
||||||
|
|
|
@ -138,23 +138,23 @@ void AddIOp::getCanonicalizationPatterns(OwningPatternList &results,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void AllocOp::build(Builder *builder, OperationState *result,
|
void AllocOp::build(Builder *builder, OperationState *result,
|
||||||
MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
|
MemRefType memrefType, ArrayRef<SSAValue *> operands) {
|
||||||
result->addOperands(operands);
|
result->addOperands(operands);
|
||||||
result->types.push_back(memrefType);
|
result->types.push_back(memrefType);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllocOp::print(OpAsmPrinter *p) const {
|
void AllocOp::print(OpAsmPrinter *p) const {
|
||||||
MemRefType *type = getType();
|
MemRefType type = getType();
|
||||||
*p << "alloc";
|
*p << "alloc";
|
||||||
// Print dynamic dimension operands.
|
// Print dynamic dimension operands.
|
||||||
printDimAndSymbolList(operand_begin(), operand_end(),
|
printDimAndSymbolList(operand_begin(), operand_end(),
|
||||||
type->getNumDynamicDims(), p);
|
type.getNumDynamicDims(), p);
|
||||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
|
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
|
||||||
*p << " : " << *type;
|
*p << " : " << type;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
MemRefType *type;
|
MemRefType type;
|
||||||
|
|
||||||
// Parse the dimension operands and optional symbol operands, followed by a
|
// Parse the dimension operands and optional symbol operands, followed by a
|
||||||
// memref type.
|
// memref type.
|
||||||
|
@ -170,7 +170,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
// Verification still checks that the total number of operands matches
|
// Verification still checks that the total number of operands matches
|
||||||
// the number of symbols in the affine map, plus the number of dynamic
|
// the number of symbols in the affine map, plus the number of dynamic
|
||||||
// dimensions in the memref.
|
// dimensions in the memref.
|
||||||
if (numDimOperands != type->getNumDynamicDims()) {
|
if (numDimOperands != type.getNumDynamicDims()) {
|
||||||
return parser->emitError(parser->getNameLoc(),
|
return parser->emitError(parser->getNameLoc(),
|
||||||
"dimension operand count does not equal memref "
|
"dimension operand count does not equal memref "
|
||||||
"dynamic dimension count");
|
"dynamic dimension count");
|
||||||
|
@ -180,13 +180,13 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AllocOp::verify() const {
|
bool AllocOp::verify() const {
|
||||||
auto *memRefType = dyn_cast<MemRefType>(getResult()->getType());
|
auto memRefType = getResult()->getType().dyn_cast<MemRefType>();
|
||||||
if (!memRefType)
|
if (!memRefType)
|
||||||
return emitOpError("result must be a memref");
|
return emitOpError("result must be a memref");
|
||||||
|
|
||||||
unsigned numSymbols = 0;
|
unsigned numSymbols = 0;
|
||||||
if (!memRefType->getAffineMaps().empty()) {
|
if (!memRefType.getAffineMaps().empty()) {
|
||||||
AffineMap affineMap = memRefType->getAffineMaps()[0];
|
AffineMap affineMap = memRefType.getAffineMaps()[0];
|
||||||
// Store number of symbols used in affine map (used in subsequent check).
|
// Store number of symbols used in affine map (used in subsequent check).
|
||||||
numSymbols = affineMap.getNumSymbols();
|
numSymbols = affineMap.getNumSymbols();
|
||||||
// TODO(zinenko): this check does not belong to AllocOp, or any other op but
|
// TODO(zinenko): this check does not belong to AllocOp, or any other op but
|
||||||
|
@ -195,10 +195,10 @@ bool AllocOp::verify() const {
|
||||||
// Remove when we can emit errors directly from *Type::get(...) functions.
|
// Remove when we can emit errors directly from *Type::get(...) functions.
|
||||||
//
|
//
|
||||||
// Verify that the layout affine map matches the rank of the memref.
|
// Verify that the layout affine map matches the rank of the memref.
|
||||||
if (affineMap.getNumDims() != memRefType->getRank())
|
if (affineMap.getNumDims() != memRefType.getRank())
|
||||||
return emitOpError("affine map dimension count must equal memref rank");
|
return emitOpError("affine map dimension count must equal memref rank");
|
||||||
}
|
}
|
||||||
unsigned numDynamicDims = memRefType->getNumDynamicDims();
|
unsigned numDynamicDims = memRefType.getNumDynamicDims();
|
||||||
// Check that the total number of operands matches the number of symbols in
|
// Check that the total number of operands matches the number of symbols in
|
||||||
// the affine map, plus the number of dynamic dimensions specified in the
|
// the affine map, plus the number of dynamic dimensions specified in the
|
||||||
// memref type.
|
// memref type.
|
||||||
|
@ -208,7 +208,7 @@ bool AllocOp::verify() const {
|
||||||
}
|
}
|
||||||
// Verify that all operands are of type Index.
|
// Verify that all operands are of type Index.
|
||||||
for (auto *operand : getOperands()) {
|
for (auto *operand : getOperands()) {
|
||||||
if (!operand->getType()->isIndex())
|
if (!operand->getType().isIndex())
|
||||||
return emitOpError("requires operands to be of type Index");
|
return emitOpError("requires operands to be of type Index");
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -239,13 +239,13 @@ struct SimplifyAllocConst : public Pattern {
|
||||||
// Ok, we have one or more constant operands. Collect the non-constant ones
|
// Ok, we have one or more constant operands. Collect the non-constant ones
|
||||||
// and keep track of the resultant memref type to build.
|
// and keep track of the resultant memref type to build.
|
||||||
SmallVector<int, 4> newShapeConstants;
|
SmallVector<int, 4> newShapeConstants;
|
||||||
newShapeConstants.reserve(memrefType->getRank());
|
newShapeConstants.reserve(memrefType.getRank());
|
||||||
SmallVector<SSAValue *, 4> newOperands;
|
SmallVector<SSAValue *, 4> newOperands;
|
||||||
SmallVector<SSAValue *, 4> droppedOperands;
|
SmallVector<SSAValue *, 4> droppedOperands;
|
||||||
|
|
||||||
unsigned dynamicDimPos = 0;
|
unsigned dynamicDimPos = 0;
|
||||||
for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) {
|
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
|
||||||
int dimSize = memrefType->getDimSize(dim);
|
int dimSize = memrefType.getDimSize(dim);
|
||||||
// If this is already static dimension, keep it.
|
// If this is already static dimension, keep it.
|
||||||
if (dimSize != -1) {
|
if (dimSize != -1) {
|
||||||
newShapeConstants.push_back(dimSize);
|
newShapeConstants.push_back(dimSize);
|
||||||
|
@ -267,10 +267,10 @@ struct SimplifyAllocConst : public Pattern {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new memref type (which will have fewer dynamic dimensions).
|
// Create new memref type (which will have fewer dynamic dimensions).
|
||||||
auto *newMemRefType = MemRefType::get(
|
auto newMemRefType = MemRefType::get(
|
||||||
newShapeConstants, memrefType->getElementType(),
|
newShapeConstants, memrefType.getElementType(),
|
||||||
memrefType->getAffineMaps(), memrefType->getMemorySpace());
|
memrefType.getAffineMaps(), memrefType.getMemorySpace());
|
||||||
assert(newOperands.size() == newMemRefType->getNumDynamicDims());
|
assert(newOperands.size() == newMemRefType.getNumDynamicDims());
|
||||||
|
|
||||||
// Create and insert the alloc op for the new memref.
|
// Create and insert the alloc op for the new memref.
|
||||||
auto newAlloc =
|
auto newAlloc =
|
||||||
|
@ -297,13 +297,13 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee,
|
||||||
ArrayRef<SSAValue *> operands) {
|
ArrayRef<SSAValue *> operands) {
|
||||||
result->addOperands(operands);
|
result->addOperands(operands);
|
||||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||||
result->addTypes(callee->getType()->getResults());
|
result->addTypes(callee->getType().getResults());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
StringRef calleeName;
|
StringRef calleeName;
|
||||||
llvm::SMLoc calleeLoc;
|
llvm::SMLoc calleeLoc;
|
||||||
FunctionType *calleeType = nullptr;
|
FunctionType calleeType;
|
||||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||||
Function *callee = nullptr;
|
Function *callee = nullptr;
|
||||||
if (parser->parseFunctionName(calleeName, calleeLoc) ||
|
if (parser->parseFunctionName(calleeName, calleeLoc) ||
|
||||||
|
@ -312,8 +312,8 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||||
parser->parseColonType(calleeType) ||
|
parser->parseColonType(calleeType) ||
|
||||||
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
|
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
|
||||||
parser->addTypesToList(calleeType->getResults(), result->types) ||
|
parser->addTypesToList(calleeType.getResults(), result->types) ||
|
||||||
parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
|
parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
|
||||||
result->operands))
|
result->operands))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
|
@ -328,7 +328,7 @@ void CallOp::print(OpAsmPrinter *p) const {
|
||||||
p->printOperands(getOperands());
|
p->printOperands(getOperands());
|
||||||
*p << ')';
|
*p << ')';
|
||||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
|
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
|
||||||
*p << " : " << *getCallee()->getType();
|
*p << " : " << getCallee()->getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CallOp::verify() const {
|
bool CallOp::verify() const {
|
||||||
|
@ -338,20 +338,20 @@ bool CallOp::verify() const {
|
||||||
return emitOpError("requires a 'callee' function attribute");
|
return emitOpError("requires a 'callee' function attribute");
|
||||||
|
|
||||||
// Verify that the operand and result types match the callee.
|
// Verify that the operand and result types match the callee.
|
||||||
auto *fnType = fnAttr.getValue()->getType();
|
auto fnType = fnAttr.getValue()->getType();
|
||||||
if (fnType->getNumInputs() != getNumOperands())
|
if (fnType.getNumInputs() != getNumOperands())
|
||||||
return emitOpError("incorrect number of operands for callee");
|
return emitOpError("incorrect number of operands for callee");
|
||||||
|
|
||||||
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
|
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
|
||||||
if (getOperand(i)->getType() != fnType->getInput(i))
|
if (getOperand(i)->getType() != fnType.getInput(i))
|
||||||
return emitOpError("operand type mismatch");
|
return emitOpError("operand type mismatch");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (fnType->getNumResults() != getNumResults())
|
if (fnType.getNumResults() != getNumResults())
|
||||||
return emitOpError("incorrect number of results for callee");
|
return emitOpError("incorrect number of results for callee");
|
||||||
|
|
||||||
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
|
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
|
||||||
if (getResult(i)->getType() != fnType->getResult(i))
|
if (getResult(i)->getType() != fnType.getResult(i))
|
||||||
return emitOpError("result type mismatch");
|
return emitOpError("result type mismatch");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -364,14 +364,14 @@ bool CallOp::verify() const {
|
||||||
|
|
||||||
void CallIndirectOp::build(Builder *builder, OperationState *result,
|
void CallIndirectOp::build(Builder *builder, OperationState *result,
|
||||||
SSAValue *callee, ArrayRef<SSAValue *> operands) {
|
SSAValue *callee, ArrayRef<SSAValue *> operands) {
|
||||||
auto *fnType = cast<FunctionType>(callee->getType());
|
auto fnType = callee->getType().cast<FunctionType>();
|
||||||
result->operands.push_back(callee);
|
result->operands.push_back(callee);
|
||||||
result->addOperands(operands);
|
result->addOperands(operands);
|
||||||
result->addTypes(fnType->getResults());
|
result->addTypes(fnType.getResults());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
FunctionType *calleeType = nullptr;
|
FunctionType calleeType;
|
||||||
OpAsmParser::OperandType callee;
|
OpAsmParser::OperandType callee;
|
||||||
llvm::SMLoc operandsLoc;
|
llvm::SMLoc operandsLoc;
|
||||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||||
|
@ -382,9 +382,9 @@ bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||||
parser->parseColonType(calleeType) ||
|
parser->parseColonType(calleeType) ||
|
||||||
parser->resolveOperand(callee, calleeType, result->operands) ||
|
parser->resolveOperand(callee, calleeType, result->operands) ||
|
||||||
parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
|
parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
|
||||||
result->operands) ||
|
result->operands) ||
|
||||||
parser->addTypesToList(calleeType->getResults(), result->types);
|
parser->addTypesToList(calleeType.getResults(), result->types);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CallIndirectOp::print(OpAsmPrinter *p) const {
|
void CallIndirectOp::print(OpAsmPrinter *p) const {
|
||||||
|
@ -395,29 +395,29 @@ void CallIndirectOp::print(OpAsmPrinter *p) const {
|
||||||
p->printOperands(++operandRange.begin(), operandRange.end());
|
p->printOperands(++operandRange.begin(), operandRange.end());
|
||||||
*p << ')';
|
*p << ')';
|
||||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
|
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
|
||||||
*p << " : " << *getCallee()->getType();
|
*p << " : " << getCallee()->getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CallIndirectOp::verify() const {
|
bool CallIndirectOp::verify() const {
|
||||||
// The callee must be a function.
|
// The callee must be a function.
|
||||||
auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
|
auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
|
||||||
if (!fnType)
|
if (!fnType)
|
||||||
return emitOpError("callee must have function type");
|
return emitOpError("callee must have function type");
|
||||||
|
|
||||||
// Verify that the operand and result types match the callee.
|
// Verify that the operand and result types match the callee.
|
||||||
if (fnType->getNumInputs() != getNumOperands() - 1)
|
if (fnType.getNumInputs() != getNumOperands() - 1)
|
||||||
return emitOpError("incorrect number of operands for callee");
|
return emitOpError("incorrect number of operands for callee");
|
||||||
|
|
||||||
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
|
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
|
||||||
if (getOperand(i + 1)->getType() != fnType->getInput(i))
|
if (getOperand(i + 1)->getType() != fnType.getInput(i))
|
||||||
return emitOpError("operand type mismatch");
|
return emitOpError("operand type mismatch");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (fnType->getNumResults() != getNumResults())
|
if (fnType.getNumResults() != getNumResults())
|
||||||
return emitOpError("incorrect number of results for callee");
|
return emitOpError("incorrect number of results for callee");
|
||||||
|
|
||||||
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
|
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
|
||||||
if (getResult(i)->getType() != fnType->getResult(i))
|
if (getResult(i)->getType() != fnType.getResult(i))
|
||||||
return emitOpError("result type mismatch");
|
return emitOpError("result type mismatch");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -434,19 +434,19 @@ void DeallocOp::build(Builder *builder, OperationState *result,
|
||||||
}
|
}
|
||||||
|
|
||||||
void DeallocOp::print(OpAsmPrinter *p) const {
|
void DeallocOp::print(OpAsmPrinter *p) const {
|
||||||
*p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
|
*p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
OpAsmParser::OperandType memrefInfo;
|
OpAsmParser::OperandType memrefInfo;
|
||||||
MemRefType *type;
|
MemRefType type;
|
||||||
|
|
||||||
return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
|
return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
|
||||||
parser->resolveOperand(memrefInfo, type, result->operands);
|
parser->resolveOperand(memrefInfo, type, result->operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool DeallocOp::verify() const {
|
bool DeallocOp::verify() const {
|
||||||
if (!isa<MemRefType>(getMemRef()->getType()))
|
if (!getMemRef()->getType().isa<MemRefType>())
|
||||||
return emitOpError("operand must be a memref");
|
return emitOpError("operand must be a memref");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -472,13 +472,13 @@ void DimOp::build(Builder *builder, OperationState *result,
|
||||||
void DimOp::print(OpAsmPrinter *p) const {
|
void DimOp::print(OpAsmPrinter *p) const {
|
||||||
*p << "dim " << *getOperand() << ", " << getIndex();
|
*p << "dim " << *getOperand() << ", " << getIndex();
|
||||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
|
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
|
||||||
*p << " : " << *getOperand()->getType();
|
*p << " : " << getOperand()->getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
OpAsmParser::OperandType operandInfo;
|
OpAsmParser::OperandType operandInfo;
|
||||||
IntegerAttr indexAttr;
|
IntegerAttr indexAttr;
|
||||||
Type *type;
|
Type type;
|
||||||
|
|
||||||
return parser->parseOperand(operandInfo) || parser->parseComma() ||
|
return parser->parseOperand(operandInfo) || parser->parseComma() ||
|
||||||
parser->parseAttribute(indexAttr, "index", result->attributes) ||
|
parser->parseAttribute(indexAttr, "index", result->attributes) ||
|
||||||
|
@ -496,15 +496,15 @@ bool DimOp::verify() const {
|
||||||
return emitOpError("requires an integer attribute named 'index'");
|
return emitOpError("requires an integer attribute named 'index'");
|
||||||
uint64_t index = (uint64_t)indexAttr.getValue();
|
uint64_t index = (uint64_t)indexAttr.getValue();
|
||||||
|
|
||||||
auto *type = getOperand()->getType();
|
auto type = getOperand()->getType();
|
||||||
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
|
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||||
if (index >= tensorType->getRank())
|
if (index >= tensorType.getRank())
|
||||||
return emitOpError("index is out of range");
|
return emitOpError("index is out of range");
|
||||||
} else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
|
} else if (auto memrefType = type.dyn_cast<MemRefType>()) {
|
||||||
if (index >= memrefType->getRank())
|
if (index >= memrefType.getRank())
|
||||||
return emitOpError("index is out of range");
|
return emitOpError("index is out of range");
|
||||||
|
|
||||||
} else if (isa<UnrankedTensorType>(type)) {
|
} else if (type.isa<UnrankedTensorType>()) {
|
||||||
// ok, assumed to be in-range.
|
// ok, assumed to be in-range.
|
||||||
} else {
|
} else {
|
||||||
return emitOpError("requires an operand with tensor or memref type");
|
return emitOpError("requires an operand with tensor or memref type");
|
||||||
|
@ -516,12 +516,12 @@ bool DimOp::verify() const {
|
||||||
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
|
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
// Constant fold dim when the size along the index referred to is a constant.
|
// Constant fold dim when the size along the index referred to is a constant.
|
||||||
auto *opType = getOperand()->getType();
|
auto opType = getOperand()->getType();
|
||||||
int indexSize = -1;
|
int indexSize = -1;
|
||||||
if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) {
|
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||||
indexSize = tensorType->getShape()[getIndex()];
|
indexSize = tensorType.getShape()[getIndex()];
|
||||||
} else if (auto *memrefType = dyn_cast<MemRefType>(opType)) {
|
} else if (auto memrefType = opType.dyn_cast<MemRefType>()) {
|
||||||
indexSize = memrefType->getShape()[getIndex()];
|
indexSize = memrefType.getShape()[getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (indexSize >= 0)
|
if (indexSize >= 0)
|
||||||
|
@ -544,9 +544,9 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
|
||||||
p->printOperands(getTagIndices());
|
p->printOperands(getTagIndices());
|
||||||
*p << ']';
|
*p << ']';
|
||||||
p->printOptionalAttrDict(getAttrs());
|
p->printOptionalAttrDict(getAttrs());
|
||||||
*p << " : " << *getSrcMemRef()->getType();
|
*p << " : " << getSrcMemRef()->getType();
|
||||||
*p << ", " << *getDstMemRef()->getType();
|
*p << ", " << getDstMemRef()->getType();
|
||||||
*p << ", " << *getTagMemRef()->getType();
|
*p << ", " << getTagMemRef()->getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse DmaStartOp.
|
// Parse DmaStartOp.
|
||||||
|
@ -566,8 +566,8 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
OpAsmParser::OperandType tagMemrefInfo;
|
OpAsmParser::OperandType tagMemrefInfo;
|
||||||
SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
|
SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
|
||||||
|
|
||||||
SmallVector<Type *, 3> types;
|
SmallVector<Type, 3> types;
|
||||||
auto *indexType = parser->getBuilder().getIndexType();
|
auto indexType = parser->getBuilder().getIndexType();
|
||||||
|
|
||||||
// Parse and resolve the following list of operands:
|
// Parse and resolve the following list of operands:
|
||||||
// *) source memref followed by its indices (in square brackets).
|
// *) source memref followed by its indices (in square brackets).
|
||||||
|
@ -601,12 +601,12 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
// Check that source/destination index list size matches associated rank.
|
// Check that source/destination index list size matches associated rank.
|
||||||
if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() ||
|
if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
|
||||||
dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank())
|
dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
|
||||||
return parser->emitError(parser->getNameLoc(),
|
return parser->emitError(parser->getNameLoc(),
|
||||||
"memref rank not equal to indices count");
|
"memref rank not equal to indices count");
|
||||||
|
|
||||||
if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank())
|
if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank())
|
||||||
return parser->emitError(parser->getNameLoc(),
|
return parser->emitError(parser->getNameLoc(),
|
||||||
"tag memref rank not equal to indices count");
|
"tag memref rank not equal to indices count");
|
||||||
|
|
||||||
|
@ -632,7 +632,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
|
||||||
p->printOperands(getTagIndices());
|
p->printOperands(getTagIndices());
|
||||||
*p << "], ";
|
*p << "], ";
|
||||||
p->printOperand(getNumElements());
|
p->printOperand(getNumElements());
|
||||||
*p << " : " << *getTagMemRef()->getType();
|
*p << " : " << getTagMemRef()->getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse DmaWaitOp.
|
// Parse DmaWaitOp.
|
||||||
|
@ -642,8 +642,8 @@ void DmaWaitOp::print(OpAsmPrinter *p) const {
|
||||||
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
OpAsmParser::OperandType tagMemrefInfo;
|
OpAsmParser::OperandType tagMemrefInfo;
|
||||||
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
|
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
|
||||||
Type *type;
|
Type type;
|
||||||
auto *indexType = parser->getBuilder().getIndexType();
|
auto indexType = parser->getBuilder().getIndexType();
|
||||||
OpAsmParser::OperandType numElementsInfo;
|
OpAsmParser::OperandType numElementsInfo;
|
||||||
|
|
||||||
// Parse tag memref, its indices, and dma size.
|
// Parse tag memref, its indices, and dma size.
|
||||||
|
@ -657,7 +657,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
parser->resolveOperand(numElementsInfo, indexType, result->operands))
|
parser->resolveOperand(numElementsInfo, indexType, result->operands))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
|
if (tagIndexInfos.size() != type.cast<MemRefType>().getRank())
|
||||||
return parser->emitError(parser->getNameLoc(),
|
return parser->emitError(parser->getNameLoc(),
|
||||||
"tag memref rank not equal to indices count");
|
"tag memref rank not equal to indices count");
|
||||||
|
|
||||||
|
@ -678,10 +678,10 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningPatternList &results,
|
||||||
void ExtractElementOp::build(Builder *builder, OperationState *result,
|
void ExtractElementOp::build(Builder *builder, OperationState *result,
|
||||||
SSAValue *aggregate,
|
SSAValue *aggregate,
|
||||||
ArrayRef<SSAValue *> indices) {
|
ArrayRef<SSAValue *> indices) {
|
||||||
auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
|
auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
|
||||||
result->addOperands(aggregate);
|
result->addOperands(aggregate);
|
||||||
result->addOperands(indices);
|
result->addOperands(indices);
|
||||||
result->types.push_back(aggregateType->getElementType());
|
result->types.push_back(aggregateType.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExtractElementOp::print(OpAsmPrinter *p) const {
|
void ExtractElementOp::print(OpAsmPrinter *p) const {
|
||||||
|
@ -689,13 +689,13 @@ void ExtractElementOp::print(OpAsmPrinter *p) const {
|
||||||
p->printOperands(getIndices());
|
p->printOperands(getIndices());
|
||||||
*p << ']';
|
*p << ']';
|
||||||
p->printOptionalAttrDict(getAttrs());
|
p->printOptionalAttrDict(getAttrs());
|
||||||
*p << " : " << *getAggregate()->getType();
|
*p << " : " << getAggregate()->getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
OpAsmParser::OperandType aggregateInfo;
|
OpAsmParser::OperandType aggregateInfo;
|
||||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||||
VectorOrTensorType *type;
|
VectorOrTensorType type;
|
||||||
|
|
||||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||||
return parser->parseOperand(aggregateInfo) ||
|
return parser->parseOperand(aggregateInfo) ||
|
||||||
|
@ -705,26 +705,26 @@ bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
parser->parseColonType(type) ||
|
parser->parseColonType(type) ||
|
||||||
parser->resolveOperand(aggregateInfo, type, result->operands) ||
|
parser->resolveOperand(aggregateInfo, type, result->operands) ||
|
||||||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
|
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
|
||||||
parser->addTypeToList(type->getElementType(), result->types);
|
parser->addTypeToList(type.getElementType(), result->types);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ExtractElementOp::verify() const {
|
bool ExtractElementOp::verify() const {
|
||||||
if (getNumOperands() == 0)
|
if (getNumOperands() == 0)
|
||||||
return emitOpError("expected an aggregate to index into");
|
return emitOpError("expected an aggregate to index into");
|
||||||
|
|
||||||
auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
|
auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>();
|
||||||
if (!aggregateType)
|
if (!aggregateType)
|
||||||
return emitOpError("first operand must be a vector or tensor");
|
return emitOpError("first operand must be a vector or tensor");
|
||||||
|
|
||||||
if (getType() != aggregateType->getElementType())
|
if (getType() != aggregateType.getElementType())
|
||||||
return emitOpError("result type must match element type of aggregate");
|
return emitOpError("result type must match element type of aggregate");
|
||||||
|
|
||||||
for (auto *idx : getIndices())
|
for (auto *idx : getIndices())
|
||||||
if (!idx->getType()->isIndex())
|
if (!idx->getType().isIndex())
|
||||||
return emitOpError("index to extract_element must have 'index' type");
|
return emitOpError("index to extract_element must have 'index' type");
|
||||||
|
|
||||||
// Verify the # indices match if we have a ranked type.
|
// Verify the # indices match if we have a ranked type.
|
||||||
auto aggregateRank = aggregateType->getRank();
|
auto aggregateRank = aggregateType.getRank();
|
||||||
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
|
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
|
||||||
return emitOpError("incorrect number of indices for extract_element");
|
return emitOpError("incorrect number of indices for extract_element");
|
||||||
|
|
||||||
|
@ -737,10 +737,10 @@ bool ExtractElementOp::verify() const {
|
||||||
|
|
||||||
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
|
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
|
||||||
ArrayRef<SSAValue *> indices) {
|
ArrayRef<SSAValue *> indices) {
|
||||||
auto *memrefType = cast<MemRefType>(memref->getType());
|
auto memrefType = memref->getType().cast<MemRefType>();
|
||||||
result->addOperands(memref);
|
result->addOperands(memref);
|
||||||
result->addOperands(indices);
|
result->addOperands(indices);
|
||||||
result->types.push_back(memrefType->getElementType());
|
result->types.push_back(memrefType.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LoadOp::print(OpAsmPrinter *p) const {
|
void LoadOp::print(OpAsmPrinter *p) const {
|
||||||
|
@ -748,13 +748,13 @@ void LoadOp::print(OpAsmPrinter *p) const {
|
||||||
p->printOperands(getIndices());
|
p->printOperands(getIndices());
|
||||||
*p << ']';
|
*p << ']';
|
||||||
p->printOptionalAttrDict(getAttrs());
|
p->printOptionalAttrDict(getAttrs());
|
||||||
*p << " : " << *getMemRefType();
|
*p << " : " << getMemRefType();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
OpAsmParser::OperandType memrefInfo;
|
OpAsmParser::OperandType memrefInfo;
|
||||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||||
MemRefType *type;
|
MemRefType type;
|
||||||
|
|
||||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||||
return parser->parseOperand(memrefInfo) ||
|
return parser->parseOperand(memrefInfo) ||
|
||||||
|
@ -764,25 +764,25 @@ bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
parser->parseColonType(type) ||
|
parser->parseColonType(type) ||
|
||||||
parser->resolveOperand(memrefInfo, type, result->operands) ||
|
parser->resolveOperand(memrefInfo, type, result->operands) ||
|
||||||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
|
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
|
||||||
parser->addTypeToList(type->getElementType(), result->types);
|
parser->addTypeToList(type.getElementType(), result->types);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool LoadOp::verify() const {
|
bool LoadOp::verify() const {
|
||||||
if (getNumOperands() == 0)
|
if (getNumOperands() == 0)
|
||||||
return emitOpError("expected a memref to load from");
|
return emitOpError("expected a memref to load from");
|
||||||
|
|
||||||
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
|
auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
|
||||||
if (!memRefType)
|
if (!memRefType)
|
||||||
return emitOpError("first operand must be a memref");
|
return emitOpError("first operand must be a memref");
|
||||||
|
|
||||||
if (getType() != memRefType->getElementType())
|
if (getType() != memRefType.getElementType())
|
||||||
return emitOpError("result type must match element type of memref");
|
return emitOpError("result type must match element type of memref");
|
||||||
|
|
||||||
if (memRefType->getRank() != getNumOperands() - 1)
|
if (memRefType.getRank() != getNumOperands() - 1)
|
||||||
return emitOpError("incorrect number of indices for load");
|
return emitOpError("incorrect number of indices for load");
|
||||||
|
|
||||||
for (auto *idx : getIndices())
|
for (auto *idx : getIndices())
|
||||||
if (!idx->getType()->isIndex())
|
if (!idx->getType().isIndex())
|
||||||
return emitOpError("index to load must have 'index' type");
|
return emitOpError("index to load must have 'index' type");
|
||||||
|
|
||||||
// TODO: Verify we have the right number of indices.
|
// TODO: Verify we have the right number of indices.
|
||||||
|
@ -804,31 +804,31 @@ void LoadOp::getCanonicalizationPatterns(OwningPatternList &results,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool MemRefCastOp::verify() const {
|
bool MemRefCastOp::verify() const {
|
||||||
auto *opType = dyn_cast<MemRefType>(getOperand()->getType());
|
auto opType = getOperand()->getType().dyn_cast<MemRefType>();
|
||||||
auto *resType = dyn_cast<MemRefType>(getType());
|
auto resType = getType().dyn_cast<MemRefType>();
|
||||||
if (!opType || !resType)
|
if (!opType || !resType)
|
||||||
return emitOpError("requires input and result types to be memrefs");
|
return emitOpError("requires input and result types to be memrefs");
|
||||||
|
|
||||||
if (opType == resType)
|
if (opType == resType)
|
||||||
return emitOpError("requires the input and result type to be different");
|
return emitOpError("requires the input and result type to be different");
|
||||||
|
|
||||||
if (opType->getElementType() != resType->getElementType())
|
if (opType.getElementType() != resType.getElementType())
|
||||||
return emitOpError(
|
return emitOpError(
|
||||||
"requires input and result element types to be the same");
|
"requires input and result element types to be the same");
|
||||||
|
|
||||||
if (opType->getAffineMaps() != resType->getAffineMaps())
|
if (opType.getAffineMaps() != resType.getAffineMaps())
|
||||||
return emitOpError("requires input and result mappings to be the same");
|
return emitOpError("requires input and result mappings to be the same");
|
||||||
|
|
||||||
if (opType->getMemorySpace() != resType->getMemorySpace())
|
if (opType.getMemorySpace() != resType.getMemorySpace())
|
||||||
return emitOpError(
|
return emitOpError(
|
||||||
"requires input and result memory spaces to be the same");
|
"requires input and result memory spaces to be the same");
|
||||||
|
|
||||||
// They must have the same rank, and any specified dimensions must match.
|
// They must have the same rank, and any specified dimensions must match.
|
||||||
if (opType->getRank() != resType->getRank())
|
if (opType.getRank() != resType.getRank())
|
||||||
return emitOpError("requires input and result ranks to match");
|
return emitOpError("requires input and result ranks to match");
|
||||||
|
|
||||||
for (unsigned i = 0, e = opType->getRank(); i != e; ++i) {
|
for (unsigned i = 0, e = opType.getRank(); i != e; ++i) {
|
||||||
int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i);
|
int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i);
|
||||||
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
|
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
|
||||||
return emitOpError("requires static dimensions to match");
|
return emitOpError("requires static dimensions to match");
|
||||||
}
|
}
|
||||||
|
@ -923,14 +923,14 @@ void StoreOp::print(OpAsmPrinter *p) const {
|
||||||
p->printOperands(getIndices());
|
p->printOperands(getIndices());
|
||||||
*p << ']';
|
*p << ']';
|
||||||
p->printOptionalAttrDict(getAttrs());
|
p->printOptionalAttrDict(getAttrs());
|
||||||
*p << " : " << *getMemRefType();
|
*p << " : " << getMemRefType();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
OpAsmParser::OperandType storeValueInfo;
|
OpAsmParser::OperandType storeValueInfo;
|
||||||
OpAsmParser::OperandType memrefInfo;
|
OpAsmParser::OperandType memrefInfo;
|
||||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||||
MemRefType *memrefType;
|
MemRefType memrefType;
|
||||||
|
|
||||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||||
return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
|
return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
|
||||||
|
@ -939,7 +939,7 @@ bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
OpAsmParser::Delimiter::Square) ||
|
OpAsmParser::Delimiter::Square) ||
|
||||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||||
parser->parseColonType(memrefType) ||
|
parser->parseColonType(memrefType) ||
|
||||||
parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
|
parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
|
||||||
result->operands) ||
|
result->operands) ||
|
||||||
parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
|
parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
|
||||||
parser->resolveOperands(indexInfo, affineIntTy, result->operands);
|
parser->resolveOperands(indexInfo, affineIntTy, result->operands);
|
||||||
|
@ -950,19 +950,19 @@ bool StoreOp::verify() const {
|
||||||
return emitOpError("expected a value to store and a memref");
|
return emitOpError("expected a value to store and a memref");
|
||||||
|
|
||||||
// Second operand is a memref type.
|
// Second operand is a memref type.
|
||||||
auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
|
auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
|
||||||
if (!memRefType)
|
if (!memRefType)
|
||||||
return emitOpError("second operand must be a memref");
|
return emitOpError("second operand must be a memref");
|
||||||
|
|
||||||
// First operand must have same type as memref element type.
|
// First operand must have same type as memref element type.
|
||||||
if (getValueToStore()->getType() != memRefType->getElementType())
|
if (getValueToStore()->getType() != memRefType.getElementType())
|
||||||
return emitOpError("first operand must have same type memref element type");
|
return emitOpError("first operand must have same type memref element type");
|
||||||
|
|
||||||
if (getNumOperands() != 2 + memRefType->getRank())
|
if (getNumOperands() != 2 + memRefType.getRank())
|
||||||
return emitOpError("store index operand count not equal to memref rank");
|
return emitOpError("store index operand count not equal to memref rank");
|
||||||
|
|
||||||
for (auto *idx : getIndices())
|
for (auto *idx : getIndices())
|
||||||
if (!idx->getType()->isIndex())
|
if (!idx->getType().isIndex())
|
||||||
return emitOpError("index to load must have 'index' type");
|
return emitOpError("index to load must have 'index' type");
|
||||||
|
|
||||||
// TODO: Verify we have the right number of indices.
|
// TODO: Verify we have the right number of indices.
|
||||||
|
@ -1046,31 +1046,31 @@ void SubIOp::getCanonicalizationPatterns(OwningPatternList &results,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool TensorCastOp::verify() const {
|
bool TensorCastOp::verify() const {
|
||||||
auto *opType = dyn_cast<TensorType>(getOperand()->getType());
|
auto opType = getOperand()->getType().dyn_cast<TensorType>();
|
||||||
auto *resType = dyn_cast<TensorType>(getType());
|
auto resType = getType().dyn_cast<TensorType>();
|
||||||
if (!opType || !resType)
|
if (!opType || !resType)
|
||||||
return emitOpError("requires input and result types to be tensors");
|
return emitOpError("requires input and result types to be tensors");
|
||||||
|
|
||||||
if (opType == resType)
|
if (opType == resType)
|
||||||
return emitOpError("requires the input and result type to be different");
|
return emitOpError("requires the input and result type to be different");
|
||||||
|
|
||||||
if (opType->getElementType() != resType->getElementType())
|
if (opType.getElementType() != resType.getElementType())
|
||||||
return emitOpError(
|
return emitOpError(
|
||||||
"requires input and result element types to be the same");
|
"requires input and result element types to be the same");
|
||||||
|
|
||||||
// If the source or destination are unranked, then the cast is valid.
|
// If the source or destination are unranked, then the cast is valid.
|
||||||
auto *opRType = dyn_cast<RankedTensorType>(opType);
|
auto opRType = opType.dyn_cast<RankedTensorType>();
|
||||||
auto *resRType = dyn_cast<RankedTensorType>(resType);
|
auto resRType = resType.dyn_cast<RankedTensorType>();
|
||||||
if (!opRType || !resRType)
|
if (!opRType || !resRType)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// If they are both ranked, they have to have the same rank, and any specified
|
// If they are both ranked, they have to have the same rank, and any specified
|
||||||
// dimensions must match.
|
// dimensions must match.
|
||||||
if (opRType->getRank() != resRType->getRank())
|
if (opRType.getRank() != resRType.getRank())
|
||||||
return emitOpError("requires input and result ranks to match");
|
return emitOpError("requires input and result ranks to match");
|
||||||
|
|
||||||
for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
|
for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) {
|
||||||
int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
|
int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i);
|
||||||
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
|
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
|
||||||
return emitOpError("requires static dimensions to match");
|
return emitOpError("requires static dimensions to match");
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
|
||||||
SmallVector<SSAValue *, 8> existingConstants;
|
SmallVector<SSAValue *, 8> existingConstants;
|
||||||
// Operation statements that were folded and that need to be erased.
|
// Operation statements that were folded and that need to be erased.
|
||||||
std::vector<OperationStmt *> opStmtsToErase;
|
std::vector<OperationStmt *> opStmtsToErase;
|
||||||
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>;
|
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>;
|
||||||
|
|
||||||
bool foldOperation(Operation *op,
|
bool foldOperation(Operation *op,
|
||||||
SmallVectorImpl<SSAValue *> &existingConstants,
|
SmallVectorImpl<SSAValue *> &existingConstants,
|
||||||
|
@ -106,7 +106,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
||||||
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
|
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
|
||||||
auto &inst = *instIt++;
|
auto &inst = *instIt++;
|
||||||
|
|
||||||
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
|
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
|
||||||
builder.setInsertionPoint(&inst);
|
builder.setInsertionPoint(&inst);
|
||||||
return builder.create<ConstantOp>(inst.getLoc(), value, type);
|
return builder.create<ConstantOp>(inst.getLoc(), value, type);
|
||||||
};
|
};
|
||||||
|
@ -134,7 +134,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
||||||
|
|
||||||
// Override the walker's operation statement visit for constant folding.
|
// Override the walker's operation statement visit for constant folding.
|
||||||
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
|
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
|
||||||
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
|
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
|
||||||
MLFuncBuilder builder(stmt);
|
MLFuncBuilder builder(stmt);
|
||||||
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
|
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
|
||||||
};
|
};
|
||||||
|
|
|
@ -77,23 +77,23 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
|
||||||
bInner.setInsertionPoint(forStmt, forStmt->begin());
|
bInner.setInsertionPoint(forStmt, forStmt->begin());
|
||||||
|
|
||||||
// Doubles the shape with a leading dimension extent of 2.
|
// Doubles the shape with a leading dimension extent of 2.
|
||||||
auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
|
auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
|
||||||
// Add the leading dimension in the shape for the double buffer.
|
// Add the leading dimension in the shape for the double buffer.
|
||||||
ArrayRef<int> shape = oldMemRefType->getShape();
|
ArrayRef<int> shape = oldMemRefType.getShape();
|
||||||
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
|
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
|
||||||
shapeSizes.insert(shapeSizes.begin(), 2);
|
shapeSizes.insert(shapeSizes.begin(), 2);
|
||||||
|
|
||||||
auto *newMemRefType =
|
auto newMemRefType =
|
||||||
bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
|
bInner.getMemRefType(shapeSizes, oldMemRefType.getElementType(), {},
|
||||||
oldMemRefType->getMemorySpace());
|
oldMemRefType.getMemorySpace());
|
||||||
return newMemRefType;
|
return newMemRefType;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
|
auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>());
|
||||||
|
|
||||||
// Create and place the alloc at the top level.
|
// Create and place the alloc at the top level.
|
||||||
MLFuncBuilder topBuilder(forStmt->getFunction());
|
MLFuncBuilder topBuilder(forStmt->getFunction());
|
||||||
auto *newMemRef = cast<MLValue>(
|
auto newMemRef = cast<MLValue>(
|
||||||
topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
|
topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
|
||||||
->getResult());
|
->getResult());
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,7 @@ private:
|
||||||
/// As part of canonicalization, we move constants to the top of the entry
|
/// As part of canonicalization, we move constants to the top of the entry
|
||||||
/// block of the current function and de-duplicate them. This keeps track of
|
/// block of the current function and de-duplicate them. This keeps track of
|
||||||
/// constants we have done this for.
|
/// constants we have done this for.
|
||||||
DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
|
DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
|
||||||
};
|
};
|
||||||
}; // end anonymous namespace
|
}; // end anonymous namespace
|
||||||
|
|
||||||
|
|
|
@ -52,9 +52,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
|
||||||
MLValue *newMemRef,
|
MLValue *newMemRef,
|
||||||
ArrayRef<MLValue *> extraIndices,
|
ArrayRef<MLValue *> extraIndices,
|
||||||
AffineMap indexRemap) {
|
AffineMap indexRemap) {
|
||||||
unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
|
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
|
||||||
(void)newMemRefRank; // unused in opt mode
|
(void)newMemRefRank; // unused in opt mode
|
||||||
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
|
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
|
||||||
(void)newMemRefRank;
|
(void)newMemRefRank;
|
||||||
if (indexRemap) {
|
if (indexRemap) {
|
||||||
assert(indexRemap.getNumInputs() == oldMemRefRank);
|
assert(indexRemap.getNumInputs() == oldMemRefRank);
|
||||||
|
@ -64,8 +64,8 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assert same elemental type.
|
// Assert same elemental type.
|
||||||
assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
|
assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
|
||||||
cast<MemRefType>(newMemRef->getType())->getElementType());
|
newMemRef->getType().cast<MemRefType>().getElementType());
|
||||||
|
|
||||||
// Check if memref was used in a non-deferencing context.
|
// Check if memref was used in a non-deferencing context.
|
||||||
for (const StmtOperand &use : oldMemRef->getUses()) {
|
for (const StmtOperand &use : oldMemRef->getUses()) {
|
||||||
|
@ -139,7 +139,7 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
|
||||||
opStmt->operand_end());
|
opStmt->operand_end());
|
||||||
|
|
||||||
// Result types don't change. Both memref's are of the same elemental type.
|
// Result types don't change. Both memref's are of the same elemental type.
|
||||||
SmallVector<Type *, 8> resultTypes;
|
SmallVector<Type, 8> resultTypes;
|
||||||
resultTypes.reserve(opStmt->getNumResults());
|
resultTypes.reserve(opStmt->getNumResults());
|
||||||
for (const auto *result : opStmt->getResults())
|
for (const auto *result : opStmt->getResults())
|
||||||
resultTypes.push_back(result->getType());
|
resultTypes.push_back(result->getType());
|
||||||
|
|
|
@ -202,15 +202,15 @@ static bool analyzeProfitability(MLFunctionMatches matches,
|
||||||
/// sizes specified by vectorSize. The MemRef lives in the same memory space as
|
/// sizes specified by vectorSize. The MemRef lives in the same memory space as
|
||||||
/// tmpl. The MemRef should be promoted to a closer memory address space in a
|
/// tmpl. The MemRef should be promoted to a closer memory address space in a
|
||||||
/// later pass.
|
/// later pass.
|
||||||
static MemRefType *getVectorizedMemRefType(MemRefType *tmpl,
|
static MemRefType getVectorizedMemRefType(MemRefType tmpl,
|
||||||
ArrayRef<int> vectorSizes) {
|
ArrayRef<int> vectorSizes) {
|
||||||
auto *elementType = tmpl->getElementType();
|
auto elementType = tmpl.getElementType();
|
||||||
assert(!dyn_cast<VectorType>(elementType) &&
|
assert(!elementType.dyn_cast<VectorType>() &&
|
||||||
"Can't vectorize an already vector type");
|
"Can't vectorize an already vector type");
|
||||||
assert(tmpl->getAffineMaps().empty() &&
|
assert(tmpl.getAffineMaps().empty() &&
|
||||||
"Unsupported non-implicit identity map");
|
"Unsupported non-implicit identity map");
|
||||||
return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {},
|
return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {},
|
||||||
tmpl->getMemorySpace());
|
tmpl.getMemorySpace());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an unaligned load with the following semantics:
|
/// Creates an unaligned load with the following semantics:
|
||||||
|
@ -258,7 +258,7 @@ static void createUnalignedLoad(MLFuncBuilder *b, Location *loc,
|
||||||
operands.insert(operands.end(), dstMemRef);
|
operands.insert(operands.end(), dstMemRef);
|
||||||
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
|
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
|
||||||
using functional::map;
|
using functional::map;
|
||||||
std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
|
std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
|
||||||
return v->getType();
|
return v->getType();
|
||||||
};
|
};
|
||||||
auto types = map(getType, operands);
|
auto types = map(getType, operands);
|
||||||
|
@ -310,7 +310,7 @@ static void createUnalignedStore(MLFuncBuilder *b, Location *loc,
|
||||||
operands.insert(operands.end(), dstMemRef);
|
operands.insert(operands.end(), dstMemRef);
|
||||||
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
|
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
|
||||||
using functional::map;
|
using functional::map;
|
||||||
std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
|
std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
|
||||||
return v->getType();
|
return v->getType();
|
||||||
};
|
};
|
||||||
auto types = map(getType, operands);
|
auto types = map(getType, operands);
|
||||||
|
@ -348,8 +348,9 @@ static std::function<ToType *(T *)> unwrapPtr() {
|
||||||
template <typename LoadOrStoreOpPointer>
|
template <typename LoadOrStoreOpPointer>
|
||||||
static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
|
static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
|
||||||
ArrayRef<int> vectorSize) {
|
ArrayRef<int> vectorSize) {
|
||||||
auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
|
auto memRefType =
|
||||||
auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
|
memoryOp->getMemRef()->getType().template cast<MemRefType>();
|
||||||
|
auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
|
||||||
|
|
||||||
// Materialize a MemRef with 1 vector.
|
// Materialize a MemRef with 1 vector.
|
||||||
auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());
|
auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());
|
||||||
|
|
Loading…
Reference in New Issue