Implement basic IR support for a builtin complex<> type. As with tuples, we

have no standard ops for working with these yet, this is simply enough to
    represent and round trip them in the printer and parser.

--

PiperOrigin-RevId: 241102728
This commit is contained in:
Chris Lattner 2019-03-29 22:23:34 -07:00 committed by Mehdi Amini
parent 1273af232c
commit 0fb905c070
10 changed files with 184 additions and 15 deletions

View File

@ -512,6 +512,8 @@ non-function-type ::= integer-type
| memref-type | memref-type
| dialect-type | dialect-type
| type-alias | type-alias
| complex-type
| tuple-type
type-list-no-parens ::= type (`,` type)* type-list-no-parens ::= type (`,` type)*
type-list-parens ::= `(` `)` type-list-parens ::= `(` `)`
@ -860,6 +862,25 @@ access pattern analysis, and for performance optimizations like vectorization,
copy elision and in-place updates. If an affine map composition is not specified copy elision and in-place updates. If an affine map composition is not specified
for the memref, the identity affine map is assumed. for the memref, the identity affine map is assumed.
#### Complex Type {#complex-type}
Syntax:
``` {.ebnf}
complex-type ::= `complex` `<` type `>`
```
The value of `complex` type represents a complex number with a parameterized
element type, which is composed of a real and imaginary value of that element
type. The element must be a floating point or integer scalar type.
Examples:
```mlir {.mlir}
complex<f32>
complex<i32>
```
#### Tuple Type {#tuple-type} #### Tuple Type {#tuple-type}
Syntax: Syntax:

View File

@ -41,6 +41,7 @@ struct VectorTypeStorage;
struct RankedTensorTypeStorage; struct RankedTensorTypeStorage;
struct UnrankedTensorTypeStorage; struct UnrankedTensorTypeStorage;
struct MemRefTypeStorage; struct MemRefTypeStorage;
struct ComplexTypeStorage;
struct TupleTypeStorage; struct TupleTypeStorage;
} // namespace detail } // namespace detail
@ -64,6 +65,7 @@ enum Kind {
RankedTensor, RankedTensor,
UnrankedTensor, UnrankedTensor,
MemRef, MemRef,
Complex,
Tuple, Tuple,
}; };
@ -421,6 +423,38 @@ private:
unsigned memorySpace, Optional<Location> location); unsigned memorySpace, Optional<Location> location);
}; };
/// The 'complex' type represents a complex number with a parameterized element
/// type, which is composed of a real and imaginary value of that element type.
///
/// The element must be a floating point or integer scalar type.
///
class ComplexType
: public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
public:
using Base::Base;
/// Get or create a ComplexType with the provided element type.
static ComplexType get(Type elementType);
/// Get or create a ComplexType with the provided element type. This emits
/// and error at the specified location and returns null if the element type
/// isn't supported.
static ComplexType getChecked(Type elementType, Location location);
/// Verify the construction of an integer type.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, Type elementType);
Type getElementType();
static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; }
private:
static ComplexType getCheckedImpl(Type elementType,
Optional<Location> location);
};
/// Tuple types represent a collection of other types. Note: This type merely /// Tuple types represent a collection of other types. Note: This type merely
/// provides a common mechanism for representing tuples in MLIR. It is up to /// provides a common mechanism for representing tuples in MLIR. It is up to
/// dialect authors to provides operations for manipulating them, e.g. /// dialect authors to provides operations for manipulating them, e.g.

View File

@ -804,6 +804,11 @@ void ModulePrinter::printType(Type type) {
os << '>'; os << '>';
return; return;
} }
case StandardTypes::Complex:
os << "complex<";
printType(type.cast<ComplexType>().getElementType());
os << '>';
return;
case StandardTypes::Tuple: { case StandardTypes::Tuple: {
auto tuple = type.cast<TupleType>(); auto tuple = type.cast<TupleType>();
os << "tuple<"; os << "tuple<";

View File

@ -106,7 +106,7 @@ struct BuiltinDialect : public Dialect {
BuiltinDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { BuiltinDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) {
addTypes<FunctionType, UnknownType, FloatType, IndexType, IntegerType, addTypes<FunctionType, UnknownType, FloatType, IndexType, IntegerType,
VectorType, RankedTensorType, UnrankedTensorType, MemRefType, VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
TupleType>(); ComplexType, TupleType>();
} }
}; };

View File

@ -26,7 +26,9 @@
using namespace mlir; using namespace mlir;
using namespace mlir::detail; using namespace mlir::detail;
/// Integer Type. //===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
/// Verify the construction of an integer type. /// Verify the construction of an integer type.
LogicalResult IntegerType::verifyConstructionInvariants( LogicalResult IntegerType::verifyConstructionInvariants(
@ -51,7 +53,9 @@ IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context,
unsigned IntegerType::getWidth() const { return getImpl()->width; } unsigned IntegerType::getWidth() const { return getImpl()->width; }
/// Float Type. //===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//
unsigned FloatType::getWidth() const { unsigned FloatType::getWidth() const {
switch (getKind()) { switch (getKind()) {
@ -95,7 +99,9 @@ unsigned Type::getIntOrFloatBitWidth() const {
return floatType.getWidth(); return floatType.getWidth();
} }
/// VectorOrTensorType //===----------------------------------------------------------------------===//
// VectorOrTensorType
//===----------------------------------------------------------------------===//
Type VectorOrTensorType::getElementType() const { Type VectorOrTensorType::getElementType() const {
return static_cast<ImplType *>(type)->elementType; return static_cast<ImplType *>(type)->elementType;
@ -180,7 +186,9 @@ bool VectorOrTensorType::hasStaticShape() const {
return llvm::none_of(getShape(), [](int64_t i) { return i < 0; }); return llvm::none_of(getShape(), [](int64_t i) { return i < 0; });
} }
/// VectorType //===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//
VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) { VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
return Base::get(elementType.getContext(), StandardTypes::Vector, shape, return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
@ -219,7 +227,9 @@ LogicalResult VectorType::verifyConstructionInvariants(
ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
/// TensorType //===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
// Check if "elementType" can be an element type of a tensor. Emit errors if // Check if "elementType" can be an element type of a tensor. Emit errors if
// location is not nullptr. Returns failure if check failed. // location is not nullptr. Returns failure if check failed.
@ -234,7 +244,9 @@ static inline LogicalResult checkTensorElementType(Optional<Location> location,
return success(); return success();
} }
/// RankedTensorType //===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape, RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
Type elementType) { Type elementType) {
@ -266,7 +278,9 @@ ArrayRef<int64_t> RankedTensorType::getShape() const {
return getImpl()->getShape(); return getImpl()->getShape();
} }
/// UnrankedTensorType //===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//
UnrankedTensorType UnrankedTensorType::get(Type elementType) { UnrankedTensorType UnrankedTensorType::get(Type elementType) {
return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor, return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
@ -284,7 +298,9 @@ LogicalResult UnrankedTensorType::verifyConstructionInvariants(
return checkTensorElementType(loc, context, elementType); return checkTensorElementType(loc, context, elementType);
} }
/// MemRefType //===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//
/// Get or create a new MemRefType defined by the arguments. If the resulting /// Get or create a new MemRefType defined by the arguments. If the resulting
/// type would be ill-formed, return nullptr. If the location is provided, /// type would be ill-formed, return nullptr. If the location is provided,
@ -313,13 +329,12 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
for (const auto &affineMap : affineMapComposition) { for (const auto &affineMap : affineMapComposition) {
if (affineMap.getNumDims() != dim) { if (affineMap.getNumDims() != dim) {
if (location) if (location)
context->emitDiagnostic( context->emitError(
*location, *location,
"memref affine map dimension mismatch between " + "memref affine map dimension mismatch between " +
(i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) + (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) +
" and affine map" + Twine(i + 1) + ": " + Twine(dim) + " and affine map" + Twine(i + 1) + ": " + Twine(dim) +
" != " + Twine(affineMap.getNumDims()), " != " + Twine(affineMap.getNumDims()));
MLIRContext::DiagnosticKind::Error);
return nullptr; return nullptr;
} }
@ -361,7 +376,36 @@ unsigned MemRefType::getNumDynamicDims() const {
return llvm::count_if(getShape(), [](int64_t i) { return i < 0; }); return llvm::count_if(getShape(), [](int64_t i) { return i < 0; });
} }
//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//
ComplexType ComplexType::get(Type elementType) {
return Base::get(elementType.getContext(), StandardTypes::Complex,
elementType);
}
ComplexType ComplexType::getChecked(Type elementType, Location location) {
return Base::getChecked(location, elementType.getContext(),
StandardTypes::Complex, elementType);
}
/// Verify the construction of an integer type.
LogicalResult ComplexType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) {
if (loc)
context->emitError(*loc, "invalid element type for complex");
return failure();
}
return success();
}
Type ComplexType::getElementType() { return getImpl()->elementType; }
//===----------------------------------------------------------------------===//
/// TupleType /// TupleType
//===----------------------------------------------------------------------===//
/// Get or create a new TupleType with the provided element types. Assumes the /// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type. /// arguments define a well-formed type.

View File

@ -256,6 +256,24 @@ struct MemRefTypeStorage : public TypeStorage {
const unsigned memorySpace; const unsigned memorySpace;
}; };
/// Complex Type Storage.
struct ComplexTypeStorage : public TypeStorage {
ComplexTypeStorage(Type elementType) : elementType(elementType) {}
/// The hash key used for uniquing.
using KeyTy = Type;
bool operator==(const KeyTy &key) const { return key == elementType; }
/// Construction.
static ComplexTypeStorage *construct(TypeStorageAllocator &allocator,
Type elementType) {
return new (allocator.allocate<ComplexTypeStorage>())
ComplexTypeStorage(elementType);
}
Type elementType;
};
/// A type representing a collection of other types. /// A type representing a collection of other types.
struct TupleTypeStorage final struct TupleTypeStorage final
: public TypeStorage, : public TypeStorage,
@ -266,7 +284,7 @@ struct TupleTypeStorage final
/// Construction. /// Construction.
static TupleTypeStorage *construct(TypeStorageAllocator &allocator, static TupleTypeStorage *construct(TypeStorageAllocator &allocator,
const ArrayRef<Type> &key) { ArrayRef<Type> key) {
// Allocate a new storage instance. // Allocate a new storage instance.
auto byteSize = TupleTypeStorage::totalSizeToAlloc<Type>(key.size()); auto byteSize = TupleTypeStorage::totalSizeToAlloc<Type>(key.size());
auto rawMem = allocator.allocate(byteSize, alignof(TupleTypeStorage)); auto rawMem = allocator.allocate(byteSize, alignof(TupleTypeStorage));

View File

@ -185,6 +185,7 @@ public:
bool allowDynamic); bool allowDynamic);
Type parseExtendedType(); Type parseExtendedType();
Type parseTensorType(); Type parseTensorType();
Type parseComplexType();
Type parseTupleType(); Type parseTupleType();
Type parseMemRefType(); Type parseMemRefType();
Type parseFunctionType(); Type parseFunctionType();
@ -320,6 +321,7 @@ ParseResult Parser::parseCommaSeparatedListUntil(
/// | vector-type /// | vector-type
/// | tensor-type /// | tensor-type
/// | memref-type /// | memref-type
/// | complex-type
/// | tuple-type /// | tuple-type
/// ///
/// index-type ::= `index` /// index-type ::= `index`
@ -333,6 +335,8 @@ Type Parser::parseNonFunctionType() {
return parseMemRefType(); return parseMemRefType();
case Token::kw_tensor: case Token::kw_tensor:
return parseTensorType(); return parseTensorType();
case Token::kw_complex:
return parseComplexType();
case Token::kw_tuple: case Token::kw_tuple:
return parseTupleType(); return parseTupleType();
case Token::kw_vector: case Token::kw_vector:
@ -571,6 +575,26 @@ Type Parser::parseTensorType() {
return RankedTensorType::getChecked(dimensions, elementType, typeLocation); return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
} }
/// Parse a complex type.
///
/// complex-type ::= `complex` `<` type `>`
///
Type Parser::parseComplexType() {
consumeToken(Token::kw_complex);
// Parse the '<'.
if (parseToken(Token::less, "expected '<' in complex type"))
return nullptr;
auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
auto elementType = parseType();
if (!elementType ||
parseToken(Token::greater, "expected '>' in complex type"))
return nullptr;
return ComplexType::getChecked(elementType, typeLocation);
}
/// Parse a tuple type. /// Parse a tuple type.
/// ///
/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`

View File

@ -87,11 +87,14 @@ TOK_OPERATOR(star, "*")
// TODO: More operator tokens // TODO: More operator tokens
// Keywords. These turn "foo" into Token::kw_foo enums. // Keywords. These turn "foo" into Token::kw_foo enums.
// NOTE: Please key these alphabetized to make it easier to find something in
// this list and to cater to OCD.
TOK_KEYWORD(attributes) TOK_KEYWORD(attributes)
TOK_KEYWORD(bf16) TOK_KEYWORD(bf16)
TOK_KEYWORD(ceildiv) TOK_KEYWORD(ceildiv)
TOK_KEYWORD(complex)
TOK_KEYWORD(dense) TOK_KEYWORD(dense)
TOK_KEYWORD(splat)
TOK_KEYWORD(f16) TOK_KEYWORD(f16)
TOK_KEYWORD(f32) TOK_KEYWORD(f32)
TOK_KEYWORD(f64) TOK_KEYWORD(f64)
@ -107,13 +110,14 @@ TOK_KEYWORD(min)
TOK_KEYWORD(mod) TOK_KEYWORD(mod)
TOK_KEYWORD(opaque) TOK_KEYWORD(opaque)
TOK_KEYWORD(size) TOK_KEYWORD(size)
TOK_KEYWORD(sparse)
TOK_KEYWORD(splat)
TOK_KEYWORD(step) TOK_KEYWORD(step)
TOK_KEYWORD(tensor) TOK_KEYWORD(tensor)
TOK_KEYWORD(to) TOK_KEYWORD(to)
TOK_KEYWORD(true) TOK_KEYWORD(true)
TOK_KEYWORD(tuple) TOK_KEYWORD(tuple)
TOK_KEYWORD(type) TOK_KEYWORD(type)
TOK_KEYWORD(sparse)
TOK_KEYWORD(vector) TOK_KEYWORD(vector)
#undef TOK_MARKER #undef TOK_MARKER

View File

@ -1059,3 +1059,18 @@ func @ssa_name_missing_eq() {
%0:2 "foo" () : () -> (i32, i32) %0:2 "foo" () : () -> (i32, i32)
return return
} }
// -----
// expected-error @+1 {{invalid element type for complex}}
func @bad_complex(complex<memref<2x4xi8>>)
// -----
// expected-error @+1 {{expected '<' in complex type}}
func @bad_complex(complex memref<2x4xi8>>)
// -----
// expected-error @+1 {{expected '>' in complex type}}
func @bad_complex(complex<i32)

View File

@ -128,6 +128,10 @@ func @memrefs_drop_triv_id_multiple(memref<2xi8, (d0) -> (d0), (d0) -> (d0)>)
func @memrefs_compose_with_id(memref<2x2xi8, (d0, d1) -> (d0, d1), func @memrefs_compose_with_id(memref<2x2xi8, (d0, d1) -> (d0, d1),
(d0, d1) -> (d1, d0)>) (d0, d1) -> (d1, d0)>)
// CHECK: func @complex_types(complex<i1>) -> complex<f32>
func @complex_types(complex<i1>) -> complex<f32>
// CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ()) // CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->()) func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())