forked from OSchip/llvm-project
Implement basic IR support for a builtin complex<> type. As with tuples, we
have no standard ops for working with these yet, this is simply enough to represent and round trip them in the printer and parser. -- PiperOrigin-RevId: 241102728
This commit is contained in:
parent
1273af232c
commit
0fb905c070
|
@ -512,6 +512,8 @@ non-function-type ::= integer-type
|
|||
| memref-type
|
||||
| dialect-type
|
||||
| type-alias
|
||||
| complex-type
|
||||
| tuple-type
|
||||
|
||||
type-list-no-parens ::= type (`,` type)*
|
||||
type-list-parens ::= `(` `)`
|
||||
|
@ -860,6 +862,25 @@ access pattern analysis, and for performance optimizations like vectorization,
|
|||
copy elision and in-place updates. If an affine map composition is not specified
|
||||
for the memref, the identity affine map is assumed.
|
||||
|
||||
#### Complex Type {#complex-type}
|
||||
|
||||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
complex-type ::= `complex` `<` type `>`
|
||||
```
|
||||
|
||||
The value of `complex` type represents a complex number with a parameterized
|
||||
element type, which is composed of a real and imaginary value of that element
|
||||
type. The element must be a floating point or integer scalar type.
|
||||
|
||||
Examples:
|
||||
|
||||
```mlir {.mlir}
|
||||
complex<f32>
|
||||
complex<i32>
|
||||
```
|
||||
|
||||
#### Tuple Type {#tuple-type}
|
||||
|
||||
Syntax:
|
||||
|
|
|
@ -41,6 +41,7 @@ struct VectorTypeStorage;
|
|||
struct RankedTensorTypeStorage;
|
||||
struct UnrankedTensorTypeStorage;
|
||||
struct MemRefTypeStorage;
|
||||
struct ComplexTypeStorage;
|
||||
struct TupleTypeStorage;
|
||||
|
||||
} // namespace detail
|
||||
|
@ -64,6 +65,7 @@ enum Kind {
|
|||
RankedTensor,
|
||||
UnrankedTensor,
|
||||
MemRef,
|
||||
Complex,
|
||||
Tuple,
|
||||
};
|
||||
|
||||
|
@ -421,6 +423,38 @@ private:
|
|||
unsigned memorySpace, Optional<Location> location);
|
||||
};
|
||||
|
||||
/// The 'complex' type represents a complex number with a parameterized element
|
||||
/// type, which is composed of a real and imaginary value of that element type.
|
||||
///
|
||||
/// The element must be a floating point or integer scalar type.
|
||||
///
|
||||
class ComplexType
|
||||
: public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Get or create a ComplexType with the provided element type.
|
||||
static ComplexType get(Type elementType);
|
||||
|
||||
/// Get or create a ComplexType with the provided element type. This emits
|
||||
/// and error at the specified location and returns null if the element type
|
||||
/// isn't supported.
|
||||
static ComplexType getChecked(Type elementType, Location location);
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, Type elementType);
|
||||
|
||||
Type getElementType();
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; }
|
||||
|
||||
private:
|
||||
static ComplexType getCheckedImpl(Type elementType,
|
||||
Optional<Location> location);
|
||||
};
|
||||
|
||||
/// Tuple types represent a collection of other types. Note: This type merely
|
||||
/// provides a common mechanism for representing tuples in MLIR. It is up to
|
||||
/// dialect authors to provides operations for manipulating them, e.g.
|
||||
|
|
|
@ -804,6 +804,11 @@ void ModulePrinter::printType(Type type) {
|
|||
os << '>';
|
||||
return;
|
||||
}
|
||||
case StandardTypes::Complex:
|
||||
os << "complex<";
|
||||
printType(type.cast<ComplexType>().getElementType());
|
||||
os << '>';
|
||||
return;
|
||||
case StandardTypes::Tuple: {
|
||||
auto tuple = type.cast<TupleType>();
|
||||
os << "tuple<";
|
||||
|
|
|
@ -106,7 +106,7 @@ struct BuiltinDialect : public Dialect {
|
|||
BuiltinDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) {
|
||||
addTypes<FunctionType, UnknownType, FloatType, IndexType, IntegerType,
|
||||
VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
|
||||
TupleType>();
|
||||
ComplexType, TupleType>();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -26,7 +26,9 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
/// Integer Type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Integer Type
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
LogicalResult IntegerType::verifyConstructionInvariants(
|
||||
|
@ -51,7 +53,9 @@ IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context,
|
|||
|
||||
unsigned IntegerType::getWidth() const { return getImpl()->width; }
|
||||
|
||||
/// Float Type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Float Type
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
unsigned FloatType::getWidth() const {
|
||||
switch (getKind()) {
|
||||
|
@ -95,7 +99,9 @@ unsigned Type::getIntOrFloatBitWidth() const {
|
|||
return floatType.getWidth();
|
||||
}
|
||||
|
||||
/// VectorOrTensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorOrTensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Type VectorOrTensorType::getElementType() const {
|
||||
return static_cast<ImplType *>(type)->elementType;
|
||||
|
@ -180,7 +186,9 @@ bool VectorOrTensorType::hasStaticShape() const {
|
|||
return llvm::none_of(getShape(), [](int64_t i) { return i < 0; });
|
||||
}
|
||||
|
||||
/// VectorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
|
||||
return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
|
||||
|
@ -219,7 +227,9 @@ LogicalResult VectorType::verifyConstructionInvariants(
|
|||
|
||||
ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
|
||||
|
||||
/// TensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Check if "elementType" can be an element type of a tensor. Emit errors if
|
||||
// location is not nullptr. Returns failure if check failed.
|
||||
|
@ -234,7 +244,9 @@ static inline LogicalResult checkTensorElementType(Optional<Location> location,
|
|||
return success();
|
||||
}
|
||||
|
||||
/// RankedTensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RankedTensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
|
||||
Type elementType) {
|
||||
|
@ -266,7 +278,9 @@ ArrayRef<int64_t> RankedTensorType::getShape() const {
|
|||
return getImpl()->getShape();
|
||||
}
|
||||
|
||||
/// UnrankedTensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UnrankedTensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
|
||||
return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
|
||||
|
@ -284,7 +298,9 @@ LogicalResult UnrankedTensorType::verifyConstructionInvariants(
|
|||
return checkTensorElementType(loc, context, elementType);
|
||||
}
|
||||
|
||||
/// MemRefType
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemRefType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Get or create a new MemRefType defined by the arguments. If the resulting
|
||||
/// type would be ill-formed, return nullptr. If the location is provided,
|
||||
|
@ -313,13 +329,12 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
|
|||
for (const auto &affineMap : affineMapComposition) {
|
||||
if (affineMap.getNumDims() != dim) {
|
||||
if (location)
|
||||
context->emitDiagnostic(
|
||||
context->emitError(
|
||||
*location,
|
||||
"memref affine map dimension mismatch between " +
|
||||
(i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) +
|
||||
" and affine map" + Twine(i + 1) + ": " + Twine(dim) +
|
||||
" != " + Twine(affineMap.getNumDims()),
|
||||
MLIRContext::DiagnosticKind::Error);
|
||||
" != " + Twine(affineMap.getNumDims()));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -361,7 +376,36 @@ unsigned MemRefType::getNumDynamicDims() const {
|
|||
return llvm::count_if(getShape(), [](int64_t i) { return i < 0; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// ComplexType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ComplexType ComplexType::get(Type elementType) {
|
||||
return Base::get(elementType.getContext(), StandardTypes::Complex,
|
||||
elementType);
|
||||
}
|
||||
|
||||
ComplexType ComplexType::getChecked(Type elementType, Location location) {
|
||||
return Base::getChecked(location, elementType.getContext(),
|
||||
StandardTypes::Complex, elementType);
|
||||
}
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
LogicalResult ComplexType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
|
||||
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) {
|
||||
if (loc)
|
||||
context->emitError(*loc, "invalid element type for complex");
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
Type ComplexType::getElementType() { return getImpl()->elementType; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// TupleType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Get or create a new TupleType with the provided element types. Assumes the
|
||||
/// arguments define a well-formed type.
|
||||
|
|
|
@ -256,6 +256,24 @@ struct MemRefTypeStorage : public TypeStorage {
|
|||
const unsigned memorySpace;
|
||||
};
|
||||
|
||||
/// Complex Type Storage.
|
||||
struct ComplexTypeStorage : public TypeStorage {
|
||||
ComplexTypeStorage(Type elementType) : elementType(elementType) {}
|
||||
|
||||
/// The hash key used for uniquing.
|
||||
using KeyTy = Type;
|
||||
bool operator==(const KeyTy &key) const { return key == elementType; }
|
||||
|
||||
/// Construction.
|
||||
static ComplexTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
Type elementType) {
|
||||
return new (allocator.allocate<ComplexTypeStorage>())
|
||||
ComplexTypeStorage(elementType);
|
||||
}
|
||||
|
||||
Type elementType;
|
||||
};
|
||||
|
||||
/// A type representing a collection of other types.
|
||||
struct TupleTypeStorage final
|
||||
: public TypeStorage,
|
||||
|
@ -266,7 +284,7 @@ struct TupleTypeStorage final
|
|||
|
||||
/// Construction.
|
||||
static TupleTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const ArrayRef<Type> &key) {
|
||||
ArrayRef<Type> key) {
|
||||
// Allocate a new storage instance.
|
||||
auto byteSize = TupleTypeStorage::totalSizeToAlloc<Type>(key.size());
|
||||
auto rawMem = allocator.allocate(byteSize, alignof(TupleTypeStorage));
|
||||
|
|
|
@ -185,6 +185,7 @@ public:
|
|||
bool allowDynamic);
|
||||
Type parseExtendedType();
|
||||
Type parseTensorType();
|
||||
Type parseComplexType();
|
||||
Type parseTupleType();
|
||||
Type parseMemRefType();
|
||||
Type parseFunctionType();
|
||||
|
@ -320,6 +321,7 @@ ParseResult Parser::parseCommaSeparatedListUntil(
|
|||
/// | vector-type
|
||||
/// | tensor-type
|
||||
/// | memref-type
|
||||
/// | complex-type
|
||||
/// | tuple-type
|
||||
///
|
||||
/// index-type ::= `index`
|
||||
|
@ -333,6 +335,8 @@ Type Parser::parseNonFunctionType() {
|
|||
return parseMemRefType();
|
||||
case Token::kw_tensor:
|
||||
return parseTensorType();
|
||||
case Token::kw_complex:
|
||||
return parseComplexType();
|
||||
case Token::kw_tuple:
|
||||
return parseTupleType();
|
||||
case Token::kw_vector:
|
||||
|
@ -571,6 +575,26 @@ Type Parser::parseTensorType() {
|
|||
return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
|
||||
}
|
||||
|
||||
/// Parse a complex type.
|
||||
///
|
||||
/// complex-type ::= `complex` `<` type `>`
|
||||
///
|
||||
Type Parser::parseComplexType() {
|
||||
consumeToken(Token::kw_complex);
|
||||
|
||||
// Parse the '<'.
|
||||
if (parseToken(Token::less, "expected '<' in complex type"))
|
||||
return nullptr;
|
||||
|
||||
auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
|
||||
auto elementType = parseType();
|
||||
if (!elementType ||
|
||||
parseToken(Token::greater, "expected '>' in complex type"))
|
||||
return nullptr;
|
||||
|
||||
return ComplexType::getChecked(elementType, typeLocation);
|
||||
}
|
||||
|
||||
/// Parse a tuple type.
|
||||
///
|
||||
/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
|
||||
|
|
|
@ -87,11 +87,14 @@ TOK_OPERATOR(star, "*")
|
|||
// TODO: More operator tokens
|
||||
|
||||
// Keywords. These turn "foo" into Token::kw_foo enums.
|
||||
|
||||
// NOTE: Please key these alphabetized to make it easier to find something in
|
||||
// this list and to cater to OCD.
|
||||
TOK_KEYWORD(attributes)
|
||||
TOK_KEYWORD(bf16)
|
||||
TOK_KEYWORD(ceildiv)
|
||||
TOK_KEYWORD(complex)
|
||||
TOK_KEYWORD(dense)
|
||||
TOK_KEYWORD(splat)
|
||||
TOK_KEYWORD(f16)
|
||||
TOK_KEYWORD(f32)
|
||||
TOK_KEYWORD(f64)
|
||||
|
@ -107,13 +110,14 @@ TOK_KEYWORD(min)
|
|||
TOK_KEYWORD(mod)
|
||||
TOK_KEYWORD(opaque)
|
||||
TOK_KEYWORD(size)
|
||||
TOK_KEYWORD(sparse)
|
||||
TOK_KEYWORD(splat)
|
||||
TOK_KEYWORD(step)
|
||||
TOK_KEYWORD(tensor)
|
||||
TOK_KEYWORD(to)
|
||||
TOK_KEYWORD(true)
|
||||
TOK_KEYWORD(tuple)
|
||||
TOK_KEYWORD(type)
|
||||
TOK_KEYWORD(sparse)
|
||||
TOK_KEYWORD(vector)
|
||||
|
||||
#undef TOK_MARKER
|
||||
|
|
|
@ -1059,3 +1059,18 @@ func @ssa_name_missing_eq() {
|
|||
%0:2 "foo" () : () -> (i32, i32)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{invalid element type for complex}}
|
||||
func @bad_complex(complex<memref<2x4xi8>>)
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected '<' in complex type}}
|
||||
func @bad_complex(complex memref<2x4xi8>>)
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected '>' in complex type}}
|
||||
func @bad_complex(complex<i32)
|
||||
|
|
|
@ -128,6 +128,10 @@ func @memrefs_drop_triv_id_multiple(memref<2xi8, (d0) -> (d0), (d0) -> (d0)>)
|
|||
func @memrefs_compose_with_id(memref<2x2xi8, (d0, d1) -> (d0, d1),
|
||||
(d0, d1) -> (d1, d0)>)
|
||||
|
||||
|
||||
// CHECK: func @complex_types(complex<i1>) -> complex<f32>
|
||||
func @complex_types(complex<i1>) -> complex<f32>
|
||||
|
||||
// CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
|
||||
func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
|
||||
|
||||
|
|
Loading…
Reference in New Issue