forked from OSchip/llvm-project
Make IndexType a standard type instead of a builtin. This also cleans up some unnecessary factory methods on the Type class.
PiperOrigin-RevId: 233640730
This commit is contained in:
parent
8de7f6c471
commit
4755774d16
|
@ -434,13 +434,14 @@ PYBIND11_MODULE(pybind, m) {
|
|||
},
|
||||
py::arg("type"), py::arg("bitwidth") = 0,
|
||||
"Returns a scalar mlir::Type using the following convention:\n"
|
||||
" - makeScalarType(c, \"bf16\") return an `mlir::Type::getBF16`\n"
|
||||
" - makeScalarType(c, \"f16\") return an `mlir::Type::getF16`\n"
|
||||
" - makeScalarType(c, \"f32\") return an `mlir::Type::getF32`\n"
|
||||
" - makeScalarType(c, \"f64\") return an `mlir::Type::getF64`\n"
|
||||
" - makeScalarType(c, \"index\") return an `mlir::Type::getIndex`\n"
|
||||
" - makeScalarType(c, \"bf16\") return an "
|
||||
"`mlir::FloatType::getBF16`\n"
|
||||
" - makeScalarType(c, \"f16\") return an `mlir::FloatType::getF16`\n"
|
||||
" - makeScalarType(c, \"f32\") return an `mlir::FloatType::getF32`\n"
|
||||
" - makeScalarType(c, \"f64\") return an `mlir::FloatType::getF64`\n"
|
||||
" - makeScalarType(c, \"index\") return an `mlir::IndexType::get`\n"
|
||||
" - makeScalarType(c, \"i\", bitwidth) return an "
|
||||
"`mlir::Type::getInteger(bitwidth)`\n\n"
|
||||
"`mlir::IntegerType::get(bitwidth)`\n\n"
|
||||
" No other combinations are currently supported.")
|
||||
.def("make_memref_type", &PythonMLIRModule::makeMemRefType,
|
||||
"Returns an mlir::MemRefType of an elemental scalar. -1 is used to "
|
||||
|
|
|
@ -572,6 +572,8 @@ called with the [`call_indirect` instruction](#'call_indirect'-operation).
|
|||
Function types are also used to indicate the arguments and results of
|
||||
[operations](#operations).
|
||||
|
||||
### Standard Types {#standard-types}
|
||||
|
||||
#### Index Type {#index-type}
|
||||
|
||||
Syntax:
|
||||
|
@ -590,10 +592,6 @@ used as an element of vector, tensor or memref type
|
|||
**Rationale:** integers of platform-specific bit widths are practical to express
|
||||
sizes, dimensionalities and subscripts.
|
||||
|
||||
TODO (Index type should not be a builtin).
|
||||
|
||||
### Standard Types {#standard-types}
|
||||
|
||||
#### Integer Type {#integer-type}
|
||||
|
||||
Syntax:
|
||||
|
|
|
@ -82,13 +82,13 @@ typedef struct {
|
|||
/// Minimal C API for exposing EDSCs to Swift, Python and other languages.
|
||||
|
||||
/// Returns a simple scalar mlir::Type using the following convention:
|
||||
/// - makeScalarType(c, "bf16") return an `mlir::Type::getBF16`
|
||||
/// - makeScalarType(c, "f16") return an `mlir::Type::getF16`
|
||||
/// - makeScalarType(c, "f32") return an `mlir::Type::getF32`
|
||||
/// - makeScalarType(c, "f64") return an `mlir::Type::getF64`
|
||||
/// - makeScalarType(c, "index") return an `mlir::Type::getIndex`
|
||||
/// - makeScalarType(c, "bf16") return an `mlir::FloatType::getBF16`
|
||||
/// - makeScalarType(c, "f16") return an `mlir::FloatType::getF16`
|
||||
/// - makeScalarType(c, "f32") return an `mlir::FloatType::getF32`
|
||||
/// - makeScalarType(c, "f64") return an `mlir::FloatType::getF64`
|
||||
/// - makeScalarType(c, "index") return an `mlir::IndexType::get`
|
||||
/// - makeScalarType(c, "i", bitwidth) return an
|
||||
/// `mlir::Type::getInteger(bitwidth)`
|
||||
/// `mlir::IntegerType::get(bitwidth)`
|
||||
///
|
||||
/// No other combinations are currently supported.
|
||||
mlir_type_t makeScalarType(mlir_context_t context, const char *name,
|
||||
|
|
|
@ -55,6 +55,9 @@ enum Kind {
|
|||
FIRST_FLOATING_POINT_TYPE = BF16,
|
||||
LAST_FLOATING_POINT_TYPE = F64,
|
||||
|
||||
// Target pointer sized integer, used (e.g.) in affine mappings.
|
||||
Index,
|
||||
|
||||
// Derived types.
|
||||
Integer,
|
||||
Vector,
|
||||
|
@ -70,6 +73,26 @@ inline bool Type::isF16() const { return getKind() == StandardTypes::F16; }
|
|||
inline bool Type::isF32() const { return getKind() == StandardTypes::F32; }
|
||||
inline bool Type::isF64() const { return getKind() == StandardTypes::F64; }
|
||||
|
||||
inline bool Type::isIndex() const { return getKind() == StandardTypes::Index; }
|
||||
|
||||
/// Index is a special integer-like type with unknown platform-dependent bit
|
||||
/// width.
|
||||
class IndexType : public Type::TypeBase<IndexType, Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Crete an IndexType instance, unique in the given context.
|
||||
static IndexType get(MLIRContext *context) {
|
||||
return Base::get(context, StandardTypes::Index);
|
||||
}
|
||||
|
||||
/// Support method to enable LLVM-style type casting.
|
||||
static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
|
||||
class IntegerType
|
||||
: public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
|
||||
|
@ -105,10 +128,6 @@ public:
|
|||
static constexpr unsigned kMaxWidth = 4096;
|
||||
};
|
||||
|
||||
inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
|
||||
return IntegerType::get(width, ctx);
|
||||
}
|
||||
|
||||
/// Return true if this is an integer type with the specified width.
|
||||
inline bool Type::isInteger(unsigned width) const {
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
|
@ -137,6 +156,20 @@ public:
|
|||
return Base::get(context, kind);
|
||||
}
|
||||
|
||||
// Convenience factories.
|
||||
static FloatType getBF16(MLIRContext *ctx) {
|
||||
return get(StandardTypes::BF16, ctx);
|
||||
}
|
||||
static FloatType getF16(MLIRContext *ctx) {
|
||||
return get(StandardTypes::F16, ctx);
|
||||
}
|
||||
static FloatType getF32(MLIRContext *ctx) {
|
||||
return get(StandardTypes::F32, ctx);
|
||||
}
|
||||
static FloatType getF64(MLIRContext *ctx) {
|
||||
return get(StandardTypes::F64, ctx);
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE &&
|
||||
|
@ -153,19 +186,6 @@ public:
|
|||
static char typeID;
|
||||
};
|
||||
|
||||
inline FloatType Type::getBF16(MLIRContext *ctx) {
|
||||
return FloatType::get(StandardTypes::BF16, ctx);
|
||||
}
|
||||
inline FloatType Type::getF16(MLIRContext *ctx) {
|
||||
return FloatType::get(StandardTypes::F16, ctx);
|
||||
}
|
||||
inline FloatType Type::getF32(MLIRContext *ctx) {
|
||||
return FloatType::get(StandardTypes::F32, ctx);
|
||||
}
|
||||
inline FloatType Type::getF64(MLIRContext *ctx) {
|
||||
return FloatType::get(StandardTypes::F64, ctx);
|
||||
}
|
||||
|
||||
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
|
||||
/// types, because many operations work on values of these aggregate types.
|
||||
class VectorOrTensorType : public Type {
|
||||
|
|
|
@ -103,11 +103,7 @@ public:
|
|||
// Builtin types.
|
||||
Function,
|
||||
Unknown,
|
||||
|
||||
// TODO(riverriddle) Index shouldn't really be a builtin.
|
||||
// Target pointer sized integer, used (e.g.) in affine mappings.
|
||||
Index,
|
||||
LAST_BUILTIN_TYPE = Index,
|
||||
LAST_BUILTIN_TYPE = Unknown,
|
||||
|
||||
// Reserve type kinds for dialect specific type system extensions.
|
||||
#define DEFINE_TYPE_KIND_RANGE(Dialect) \
|
||||
|
@ -225,14 +221,6 @@ public:
|
|||
/// Return true of this is an integer or a float type.
|
||||
bool isIntOrFloat() const;
|
||||
|
||||
// Convenience factories.
|
||||
static IndexType getIndex(MLIRContext *ctx);
|
||||
static IntegerType getInteger(unsigned width, MLIRContext *ctx);
|
||||
static FloatType getBF16(MLIRContext *ctx);
|
||||
static FloatType getF16(MLIRContext *ctx);
|
||||
static FloatType getF32(MLIRContext *ctx);
|
||||
static FloatType getF64(MLIRContext *ctx);
|
||||
|
||||
/// Print the current type.
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
@ -289,26 +277,6 @@ public:
|
|||
static char typeID;
|
||||
};
|
||||
|
||||
inline bool Type::isIndex() const { return getKind() == Kind::Index; }
|
||||
|
||||
/// Index is special integer-like type with unknown platform-dependent bit width
|
||||
/// used in subscripts and loop induction variables.
|
||||
class IndexType : public Type::TypeBase<IndexType, Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Crete an IndexType instance, unique in the given context.
|
||||
static IndexType get(MLIRContext *context) {
|
||||
return Base::get(context, Kind::Index);
|
||||
}
|
||||
|
||||
/// Support method to enable LLVM-style type casting.
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Index; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// Unknown types represent types of non-registered dialects. These are types
|
||||
/// represented in their raw string form, and can only usefully be tested for
|
||||
/// type equality.
|
||||
|
@ -333,10 +301,6 @@ public:
|
|||
static char typeID;
|
||||
};
|
||||
|
||||
inline IndexType Type::getIndex(MLIRContext *ctx) {
|
||||
return IndexType::get(ctx);
|
||||
}
|
||||
|
||||
// Make Type hashable.
|
||||
inline ::llvm::hash_code hash_value(Type arg) {
|
||||
return ::llvm::hash_value(arg.type);
|
||||
|
|
|
@ -599,15 +599,18 @@ mlir_type_t makeScalarType(mlir_context_t context, const char *name,
|
|||
mlir_type_t res =
|
||||
llvm::StringSwitch<mlir_type_t>(name)
|
||||
.Case("bf16",
|
||||
mlir_type_t{mlir::Type::getBF16(c).getAsOpaquePointer()})
|
||||
.Case("f16", mlir_type_t{mlir::Type::getF16(c).getAsOpaquePointer()})
|
||||
.Case("f32", mlir_type_t{mlir::Type::getF32(c).getAsOpaquePointer()})
|
||||
.Case("f64", mlir_type_t{mlir::Type::getF64(c).getAsOpaquePointer()})
|
||||
mlir_type_t{mlir::FloatType::getBF16(c).getAsOpaquePointer()})
|
||||
.Case("f16",
|
||||
mlir_type_t{mlir::FloatType::getF16(c).getAsOpaquePointer()})
|
||||
.Case("f32",
|
||||
mlir_type_t{mlir::FloatType::getF32(c).getAsOpaquePointer()})
|
||||
.Case("f64",
|
||||
mlir_type_t{mlir::FloatType::getF64(c).getAsOpaquePointer()})
|
||||
.Case("index",
|
||||
mlir_type_t{mlir::Type::getIndex(c).getAsOpaquePointer()})
|
||||
mlir_type_t{mlir::IndexType::get(c).getAsOpaquePointer()})
|
||||
.Case("i",
|
||||
mlir_type_t{
|
||||
mlir::Type::getInteger(bitwidth, c).getAsOpaquePointer()})
|
||||
mlir::IntegerType::get(bitwidth, c).getAsOpaquePointer()})
|
||||
.Default(mlir_type_t{nullptr});
|
||||
if (!res) {
|
||||
llvm_unreachable("Invalid type specifier");
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include "AffineMapDetail.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Support/Functional.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
@ -57,7 +57,7 @@ public:
|
|||
return constantFoldBinExpr(
|
||||
expr, [](int64_t lhs, uint64_t rhs) { return ceilDiv(lhs, rhs); });
|
||||
case AffineExprKind::Constant:
|
||||
return IntegerAttr::get(Type::getIndex(expr.getContext()),
|
||||
return IntegerAttr::get(IndexType::get(expr.getContext()),
|
||||
expr.cast<AffineConstantExpr>().getValue());
|
||||
case AffineExprKind::DimId:
|
||||
return operandConsts[expr.cast<AffineDimExpr>().getPosition()]
|
||||
|
|
|
@ -720,7 +720,7 @@ void ModulePrinter::printType(Type type) {
|
|||
<< unknownTy.getTypeData() << "\">";
|
||||
return;
|
||||
}
|
||||
case Type::Kind::Index:
|
||||
case StandardTypes::Index:
|
||||
os << "index";
|
||||
return;
|
||||
case StandardTypes::BF16:
|
||||
|
|
|
@ -56,20 +56,20 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
|
|||
// Types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FloatType Builder::getBF16Type() { return Type::getBF16(context); }
|
||||
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
|
||||
|
||||
FloatType Builder::getF16Type() { return Type::getF16(context); }
|
||||
FloatType Builder::getF16Type() { return FloatType::getF16(context); }
|
||||
|
||||
FloatType Builder::getF32Type() { return Type::getF32(context); }
|
||||
FloatType Builder::getF32Type() { return FloatType::getF32(context); }
|
||||
|
||||
FloatType Builder::getF64Type() { return Type::getF64(context); }
|
||||
FloatType Builder::getF64Type() { return FloatType::getF64(context); }
|
||||
|
||||
IndexType Builder::getIndexType() { return Type::getIndex(context); }
|
||||
IndexType Builder::getIndexType() { return IndexType::get(context); }
|
||||
|
||||
IntegerType Builder::getI1Type() { return Type::getInteger(1, context); }
|
||||
IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
|
||||
|
||||
IntegerType Builder::getIntegerType(unsigned width) {
|
||||
return Type::getInteger(width, context);
|
||||
return IntegerType::get(width, context);
|
||||
}
|
||||
|
||||
FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
|
||||
|
|
|
@ -38,7 +38,7 @@ using namespace mlir;
|
|||
BuiltinDialect::BuiltinDialect(MLIRContext *context)
|
||||
: Dialect(/*namePrefix=*/"", context) {
|
||||
addOperations<BranchOp, CondBranchOp, ConstantOp, ReturnOp>();
|
||||
addTypes<FunctionType, IndexType, UnknownType, FloatType, IntegerType,
|
||||
addTypes<FunctionType, UnknownType, FloatType, IndexType, IntegerType,
|
||||
VectorType, RankedTensorType, UnrankedTensorType, MemRefType>();
|
||||
}
|
||||
|
||||
|
|
|
@ -369,6 +369,7 @@ unsigned MemRefType::getNumDynamicDims() const {
|
|||
|
||||
// Define type identifiers.
|
||||
char FloatType::typeID = 0;
|
||||
char IndexType::typeID = 0;
|
||||
char IntegerType::typeID = 0;
|
||||
char VectorType::typeID = 0;
|
||||
char RankedTensorType::typeID = 0;
|
||||
|
|
|
@ -70,5 +70,4 @@ StringRef UnknownType::getTypeData() const {
|
|||
|
||||
// Define type identifiers.
|
||||
char FunctionType::typeID = 0;
|
||||
char IndexType::typeID = 0;
|
||||
char UnknownType::typeID = 0;
|
||||
|
|
|
@ -824,7 +824,7 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
|
|||
}
|
||||
|
||||
if (indexSize >= 0)
|
||||
return IntegerAttr::get(Type::getIndex(context), indexSize);
|
||||
return IntegerAttr::get(IndexType::get(context), indexSize);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -721,7 +721,7 @@ static bool materialize(Function *f,
|
|||
// Set scoped super-vector and corresponding hw vector types.
|
||||
state->superVectorType = terminator->getVectorType();
|
||||
assert((state->superVectorType.getElementType() ==
|
||||
Type::getF32(term->getContext())) &&
|
||||
FloatType::getF32(term->getContext())) &&
|
||||
"Only f32 supported for now");
|
||||
state->hwVectorType = VectorType::get(
|
||||
state->hwVectorSize, state->superVectorType.getElementType());
|
||||
|
@ -751,7 +751,7 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
|
|||
// Get the hardware vector type.
|
||||
// TODO(ntv): get elemental type from super-vector type rather than force f32.
|
||||
auto subVectorType =
|
||||
VectorType::get(state.hwVectorSize, Type::getF32(f->getContext()));
|
||||
VectorType::get(state.hwVectorSize, FloatType::getF32(f->getContext()));
|
||||
|
||||
// Capture terminators; i.e. vector_transfer_write ops involving a strict
|
||||
// super-vector of subVectorType.
|
||||
|
|
|
@ -107,12 +107,13 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
|
|||
using matcher::Op;
|
||||
SmallVector<int64_t, 8> shape(clTestVectorShapeRatio.begin(),
|
||||
clTestVectorShapeRatio.end());
|
||||
auto subVectorType = VectorType::get(shape, Type::getF32(f->getContext()));
|
||||
auto subVectorType =
|
||||
VectorType::get(shape, FloatType::getF32(f->getContext()));
|
||||
// Only filter instructions that operate on a strict super-vector and have one
|
||||
// return. This makes testing easier.
|
||||
auto filter = [subVectorType](const Instruction &inst) {
|
||||
assert(subVectorType.getElementType() ==
|
||||
Type::getF32(subVectorType.getContext()) &&
|
||||
FloatType::getF32(subVectorType.getContext()) &&
|
||||
"Only f32 supported for now");
|
||||
if (!matcher::operatesOnSuperVectors(inst, subVectorType)) {
|
||||
return false;
|
||||
|
|
Loading…
Reference in New Issue