Make IndexType a standard type instead of a builtin. This also cleans up some unnecessary factory methods on the Type class.

PiperOrigin-RevId: 233640730
This commit is contained in:
River Riddle 2019-02-12 11:08:04 -08:00 committed by jpienaar
parent 8de7f6c471
commit 4755774d16
15 changed files with 80 additions and 93 deletions

View File

@ -434,13 +434,14 @@ PYBIND11_MODULE(pybind, m) {
},
py::arg("type"), py::arg("bitwidth") = 0,
"Returns a scalar mlir::Type using the following convention:\n"
" - makeScalarType(c, \"bf16\") return an `mlir::Type::getBF16`\n"
" - makeScalarType(c, \"f16\") return an `mlir::Type::getF16`\n"
" - makeScalarType(c, \"f32\") return an `mlir::Type::getF32`\n"
" - makeScalarType(c, \"f64\") return an `mlir::Type::getF64`\n"
" - makeScalarType(c, \"index\") return an `mlir::Type::getIndex`\n"
" - makeScalarType(c, \"bf16\") return an "
"`mlir::FloatType::getBF16`\n"
" - makeScalarType(c, \"f16\") return an `mlir::FloatType::getF16`\n"
" - makeScalarType(c, \"f32\") return an `mlir::FloatType::getF32`\n"
" - makeScalarType(c, \"f64\") return an `mlir::FloatType::getF64`\n"
" - makeScalarType(c, \"index\") return an `mlir::IndexType::get`\n"
" - makeScalarType(c, \"i\", bitwidth) return an "
"`mlir::Type::getInteger(bitwidth)`\n\n"
"`mlir::IntegerType::get(bitwidth)`\n\n"
" No other combinations are currently supported.")
.def("make_memref_type", &PythonMLIRModule::makeMemRefType,
"Returns an mlir::MemRefType of an elemental scalar. -1 is used to "

View File

@ -572,6 +572,8 @@ called with the [`call_indirect` instruction](#'call_indirect'-operation).
Function types are also used to indicate the arguments and results of
[operations](#operations).
### Standard Types {#standard-types}
#### Index Type {#index-type}
Syntax:
@ -590,10 +592,6 @@ used as an element of vector, tensor or memref type
**Rationale:** integers of platform-specific bit widths are practical to express
sizes, dimensionalities and subscripts.
TODO (Index type should not be a builtin).
### Standard Types {#standard-types}
#### Integer Type {#integer-type}
Syntax:

View File

@ -82,13 +82,13 @@ typedef struct {
/// Minimal C API for exposing EDSCs to Swift, Python and other languages.
/// Returns a simple scalar mlir::Type using the following convention:
/// - makeScalarType(c, "bf16") return an `mlir::Type::getBF16`
/// - makeScalarType(c, "f16") return an `mlir::Type::getF16`
/// - makeScalarType(c, "f32") return an `mlir::Type::getF32`
/// - makeScalarType(c, "f64") return an `mlir::Type::getF64`
/// - makeScalarType(c, "index") return an `mlir::Type::getIndex`
/// - makeScalarType(c, "bf16") return an `mlir::FloatType::getBF16`
/// - makeScalarType(c, "f16") return an `mlir::FloatType::getF16`
/// - makeScalarType(c, "f32") return an `mlir::FloatType::getF32`
/// - makeScalarType(c, "f64") return an `mlir::FloatType::getF64`
/// - makeScalarType(c, "index") return an `mlir::IndexType::get`
/// - makeScalarType(c, "i", bitwidth) return an
/// `mlir::Type::getInteger(bitwidth)`
/// `mlir::IntegerType::get(bitwidth)`
///
/// No other combinations are currently supported.
mlir_type_t makeScalarType(mlir_context_t context, const char *name,

View File

@ -55,6 +55,9 @@ enum Kind {
FIRST_FLOATING_POINT_TYPE = BF16,
LAST_FLOATING_POINT_TYPE = F64,
// Target pointer sized integer, used (e.g.) in affine mappings.
Index,
// Derived types.
Integer,
Vector,
@ -70,6 +73,26 @@ inline bool Type::isF16() const { return getKind() == StandardTypes::F16; }
inline bool Type::isF32() const { return getKind() == StandardTypes::F32; }
inline bool Type::isF64() const { return getKind() == StandardTypes::F64; }
inline bool Type::isIndex() const { return getKind() == StandardTypes::Index; }
/// Index is a special integer-like type with unknown platform-dependent bit
/// width.
class IndexType : public Type::TypeBase<IndexType, Type> {
public:
using Base::Base;
/// Crete an IndexType instance, unique in the given context.
static IndexType get(MLIRContext *context) {
return Base::get(context, StandardTypes::Index);
}
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
/// Unique identifier for this type class.
static char typeID;
};
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
class IntegerType
: public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
@ -105,10 +128,6 @@ public:
static constexpr unsigned kMaxWidth = 4096;
};
inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
return IntegerType::get(width, ctx);
}
/// Return true if this is an integer type with the specified width.
inline bool Type::isInteger(unsigned width) const {
if (auto intTy = dyn_cast<IntegerType>())
@ -137,6 +156,20 @@ public:
return Base::get(context, kind);
}
// Convenience factories.
static FloatType getBF16(MLIRContext *ctx) {
return get(StandardTypes::BF16, ctx);
}
static FloatType getF16(MLIRContext *ctx) {
return get(StandardTypes::F16, ctx);
}
static FloatType getF32(MLIRContext *ctx) {
return get(StandardTypes::F32, ctx);
}
static FloatType getF64(MLIRContext *ctx) {
return get(StandardTypes::F64, ctx);
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE &&
@ -153,19 +186,6 @@ public:
static char typeID;
};
inline FloatType Type::getBF16(MLIRContext *ctx) {
return FloatType::get(StandardTypes::BF16, ctx);
}
inline FloatType Type::getF16(MLIRContext *ctx) {
return FloatType::get(StandardTypes::F16, ctx);
}
inline FloatType Type::getF32(MLIRContext *ctx) {
return FloatType::get(StandardTypes::F32, ctx);
}
inline FloatType Type::getF64(MLIRContext *ctx) {
return FloatType::get(StandardTypes::F64, ctx);
}
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
/// types, because many operations work on values of these aggregate types.
class VectorOrTensorType : public Type {

View File

@ -103,11 +103,7 @@ public:
// Builtin types.
Function,
Unknown,
// TODO(riverriddle) Index shouldn't really be a builtin.
// Target pointer sized integer, used (e.g.) in affine mappings.
Index,
LAST_BUILTIN_TYPE = Index,
LAST_BUILTIN_TYPE = Unknown,
// Reserve type kinds for dialect specific type system extensions.
#define DEFINE_TYPE_KIND_RANGE(Dialect) \
@ -225,14 +221,6 @@ public:
/// Return true of this is an integer or a float type.
bool isIntOrFloat() const;
// Convenience factories.
static IndexType getIndex(MLIRContext *ctx);
static IntegerType getInteger(unsigned width, MLIRContext *ctx);
static FloatType getBF16(MLIRContext *ctx);
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
/// Print the current type.
void print(raw_ostream &os) const;
void dump() const;
@ -289,26 +277,6 @@ public:
static char typeID;
};
inline bool Type::isIndex() const { return getKind() == Kind::Index; }
/// Index is special integer-like type with unknown platform-dependent bit width
/// used in subscripts and loop induction variables.
class IndexType : public Type::TypeBase<IndexType, Type> {
public:
using Base::Base;
/// Crete an IndexType instance, unique in the given context.
static IndexType get(MLIRContext *context) {
return Base::get(context, Kind::Index);
}
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) { return kind == Kind::Index; }
/// Unique identifier for this type class.
static char typeID;
};
/// Unknown types represent types of non-registered dialects. These are types
/// represented in their raw string form, and can only usefully be tested for
/// type equality.
@ -333,10 +301,6 @@ public:
static char typeID;
};
inline IndexType Type::getIndex(MLIRContext *ctx) {
return IndexType::get(ctx);
}
// Make Type hashable.
inline ::llvm::hash_code hash_value(Type arg) {
return ::llvm::hash_value(arg.type);

View File

@ -599,15 +599,18 @@ mlir_type_t makeScalarType(mlir_context_t context, const char *name,
mlir_type_t res =
llvm::StringSwitch<mlir_type_t>(name)
.Case("bf16",
mlir_type_t{mlir::Type::getBF16(c).getAsOpaquePointer()})
.Case("f16", mlir_type_t{mlir::Type::getF16(c).getAsOpaquePointer()})
.Case("f32", mlir_type_t{mlir::Type::getF32(c).getAsOpaquePointer()})
.Case("f64", mlir_type_t{mlir::Type::getF64(c).getAsOpaquePointer()})
mlir_type_t{mlir::FloatType::getBF16(c).getAsOpaquePointer()})
.Case("f16",
mlir_type_t{mlir::FloatType::getF16(c).getAsOpaquePointer()})
.Case("f32",
mlir_type_t{mlir::FloatType::getF32(c).getAsOpaquePointer()})
.Case("f64",
mlir_type_t{mlir::FloatType::getF64(c).getAsOpaquePointer()})
.Case("index",
mlir_type_t{mlir::Type::getIndex(c).getAsOpaquePointer()})
mlir_type_t{mlir::IndexType::get(c).getAsOpaquePointer()})
.Case("i",
mlir_type_t{
mlir::Type::getInteger(bitwidth, c).getAsOpaquePointer()})
mlir::IntegerType::get(bitwidth, c).getAsOpaquePointer()})
.Default(mlir_type_t{nullptr});
if (!res) {
llvm_unreachable("Invalid type specifier");

View File

@ -19,7 +19,7 @@
#include "AffineMapDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringRef.h"
@ -57,7 +57,7 @@ public:
return constantFoldBinExpr(
expr, [](int64_t lhs, uint64_t rhs) { return ceilDiv(lhs, rhs); });
case AffineExprKind::Constant:
return IntegerAttr::get(Type::getIndex(expr.getContext()),
return IntegerAttr::get(IndexType::get(expr.getContext()),
expr.cast<AffineConstantExpr>().getValue());
case AffineExprKind::DimId:
return operandConsts[expr.cast<AffineDimExpr>().getPosition()]

View File

@ -720,7 +720,7 @@ void ModulePrinter::printType(Type type) {
<< unknownTy.getTypeData() << "\">";
return;
}
case Type::Kind::Index:
case StandardTypes::Index:
os << "index";
return;
case StandardTypes::BF16:

View File

@ -56,20 +56,20 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//
FloatType Builder::getBF16Type() { return Type::getBF16(context); }
FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
FloatType Builder::getF16Type() { return Type::getF16(context); }
FloatType Builder::getF16Type() { return FloatType::getF16(context); }
FloatType Builder::getF32Type() { return Type::getF32(context); }
FloatType Builder::getF32Type() { return FloatType::getF32(context); }
FloatType Builder::getF64Type() { return Type::getF64(context); }
FloatType Builder::getF64Type() { return FloatType::getF64(context); }
IndexType Builder::getIndexType() { return Type::getIndex(context); }
IndexType Builder::getIndexType() { return IndexType::get(context); }
IntegerType Builder::getI1Type() { return Type::getInteger(1, context); }
IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
IntegerType Builder::getIntegerType(unsigned width) {
return Type::getInteger(width, context);
return IntegerType::get(width, context);
}
FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,

View File

@ -38,7 +38,7 @@ using namespace mlir;
BuiltinDialect::BuiltinDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
addOperations<BranchOp, CondBranchOp, ConstantOp, ReturnOp>();
addTypes<FunctionType, IndexType, UnknownType, FloatType, IntegerType,
addTypes<FunctionType, UnknownType, FloatType, IndexType, IntegerType,
VectorType, RankedTensorType, UnrankedTensorType, MemRefType>();
}

View File

@ -369,6 +369,7 @@ unsigned MemRefType::getNumDynamicDims() const {
// Define type identifiers.
char FloatType::typeID = 0;
char IndexType::typeID = 0;
char IntegerType::typeID = 0;
char VectorType::typeID = 0;
char RankedTensorType::typeID = 0;

View File

@ -70,5 +70,4 @@ StringRef UnknownType::getTypeData() const {
// Define type identifiers.
char FunctionType::typeID = 0;
char IndexType::typeID = 0;
char UnknownType::typeID = 0;

View File

@ -824,7 +824,7 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
}
if (indexSize >= 0)
return IntegerAttr::get(Type::getIndex(context), indexSize);
return IntegerAttr::get(IndexType::get(context), indexSize);
return nullptr;
}

View File

@ -721,7 +721,7 @@ static bool materialize(Function *f,
// Set scoped super-vector and corresponding hw vector types.
state->superVectorType = terminator->getVectorType();
assert((state->superVectorType.getElementType() ==
Type::getF32(term->getContext())) &&
FloatType::getF32(term->getContext())) &&
"Only f32 supported for now");
state->hwVectorType = VectorType::get(
state->hwVectorSize, state->superVectorType.getElementType());
@ -751,7 +751,7 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
// Get the hardware vector type.
// TODO(ntv): get elemental type from super-vector type rather than force f32.
auto subVectorType =
VectorType::get(state.hwVectorSize, Type::getF32(f->getContext()));
VectorType::get(state.hwVectorSize, FloatType::getF32(f->getContext()));
// Capture terminators; i.e. vector_transfer_write ops involving a strict
// super-vector of subVectorType.

View File

@ -107,12 +107,13 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
using matcher::Op;
SmallVector<int64_t, 8> shape(clTestVectorShapeRatio.begin(),
clTestVectorShapeRatio.end());
auto subVectorType = VectorType::get(shape, Type::getF32(f->getContext()));
auto subVectorType =
VectorType::get(shape, FloatType::getF32(f->getContext()));
// Only filter instructions that operate on a strict super-vector and have one
// return. This makes testing easier.
auto filter = [subVectorType](const Instruction &inst) {
assert(subVectorType.getElementType() ==
Type::getF32(subVectorType.getContext()) &&
FloatType::getF32(subVectorType.getContext()) &&
"Only f32 supported for now");
if (!matcher::operatesOnSuperVectors(inst, subVectorType)) {
return false;