[mlir] expose standard types to C API

Provide C API for MLIR standard types. Since standard types live under lib/IR
in core MLIR, place the C APIs in the IR library as well (standard ops will go
into a separate library). This also defines a placeholder for affine maps that
are necessary to construct a memref, but are not yet exposed to the C API.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D86094
This commit is contained in:
Alex Zinenko 2020-08-18 10:26:30 +02:00
parent 9f63dc3265
commit 74f577845e
13 changed files with 886 additions and 45 deletions

View File

@ -75,6 +75,28 @@ check if an object is null by using `mlirXIsNull(MlirX)`. API functions do _not_
expect null objects as arguments unless explicitly stated otherwise. API
functions _may_ return null objects.
### Type Hierarchies
MLIR objects can form type hierarchies in C++. For example, all IR classes
representing types are derived from `mlir::Type`, some of them may also be also
derived from common base classes such as `mlir::ShapedType` or dialect-specific
base classes. Type hierarchies are exposed to C API through naming conventions
as follows.
- Only the top-level class of each hierarchy is exposed, e.g. `MlirType` is
defined as a type but `MlirShapedType` is not. This avoids the need for
explicit upcasting when passing an object of a derived type to a function
that expects a base type (this happens more often in core/standard APIs,
while downcasting usually involves further checks anyway).
- A type `Y` that derives from `X` provides a function `int mlirXIsAY(MlirX)`
that returns a non-zero value if the given dynamic instance of `X` is also
an instance of `Y`. For example, `int MlirTypeIsAInteger(MlirType)`.
- A function that expects a derived type as its first argument takes the base
type instead and documents the expectation by using `Y` in its name
`MlirY<...>(MlirX, ...)`. This function asserts that the dynamic instance of
its first argument is `Y`, and it is the responsibility of the caller to
ensure it is indeed the case.
### Conversion To String and Printing
IR objects can be converted to a string representation, for example for
@ -96,11 +118,11 @@ allocation and avoid unnecessary allocation and copying inside the printer.
For convenience, `mlirXDump(MlirX)` functions are provided to print the given
object to the standard error stream.
### Common Patterns
## Common Patterns
The API adopts the following patterns for recurrent functionality in MLIR.
#### Indexed Components
### Indexed Components
An object has an _indexed component_ if it has fields accessible using a
zero-based contiguous integer index, typically arrays. For example, an
@ -120,7 +142,7 @@ Note that the name of subobject in the function does not necessarily match the
type of the subobject. For example, `mlirOperationGetOperand` returns a
`MlirValue`.
#### Iterable Components
### Iterable Components
An object has an _iterable component_ if it has iterators accessing its fields
in some order other than integer indexing, typically linked lists. For example,
@ -146,3 +168,17 @@ for (iter = mlirXGetFirst<Y>(x); !mlirYIsNull(iter);
/* User 'iter'. */
}
```
## Extending the API
### Extensions for Dialect Attributes and Types
Dialect attributes and types can follow the example of standard attrbutes and
types, provided that implementations live in separate directories, i.e.
`include/mlir-c/<...>Dialect/` and `lib/CAPI/<...>Dialect/`. The core APIs
provide implementation-private headers in `include/mlir/CAPI/IR` that allow one
to convert between opaque C structures for core IR components and their C++
counterparts. `wrap` converts a C++ class into a C structure and `unwrap` does
the inverse conversion. Once the a C++ object is available, the API
implementation should rely on `isa` to implement `mlirXIsAY` and is expected to
use `cast` inside other API calls.

View File

@ -0,0 +1,25 @@
/*===-- mlir-c/AffineMap.h - C API for MLIR Affine maps -----------*- C -*-===*\
|* *|
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|* Exceptions. *|
|* See https://llvm.org/LICENSE.txt for license information. *|
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|* *|
\*===----------------------------------------------------------------------===*/
#ifndef MLIR_C_AFFINEMAP_H
#define MLIR_C_AFFINEMAP_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
DEFINE_C_API_STRUCT(MlirAffineMap, const void);
#ifdef __cplusplus
}
#endif
#endif // MLIR_C_AFFINEMAP_H

View File

@ -56,8 +56,6 @@ DEFINE_C_API_STRUCT(MlirType, const void);
DEFINE_C_API_STRUCT(MlirLocation, const void);
DEFINE_C_API_STRUCT(MlirModule, const void);
#undef DEFINE_C_API_STRUCT
/** Named MLIR attribute.
*
* A named attribute is essentially a (name, attribute) pair where the name is
@ -314,6 +312,9 @@ void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
/** Parses a type. The type is owned by the context. */
MlirType mlirTypeParseGet(MlirContext context, const char *type);
/** Checks if two types are equal. */
int mlirTypeEqual(MlirType t1, MlirType t2);
/** Prints a location by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */

View File

@ -0,0 +1,249 @@
/*===-- mlir-c/StandardTypes.h - C API for MLIR Standard types ----*- C -*-===*\
|* *|
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|* Exceptions. *|
|* See https://llvm.org/LICENSE.txt for license information. *|
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|* *|
\*===----------------------------------------------------------------------===*/
#ifndef MLIR_C_STANDARDTYPES_H
#define MLIR_C_STANDARDTYPES_H
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
/*============================================================================*/
/* Integer types. */
/*============================================================================*/
/** Checks whether the given type is an integer type. */
int mlirTypeIsAInteger(MlirType type);
/** Creates a signless integer type of the given bitwidth in the context. The
* type is owned by the context. */
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth);
/** Creates a signed integer type of the given bitwidth in the context. The type
* is owned by the context. */
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth);
/** Creates an unsigned integer type of the given bitwidth in the context. The
* type is owned by the context. */
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth);
/** Returns the bitwidth of an integer type. */
unsigned mlirIntegerTypeGetWidth(MlirType type);
/** Checks whether the given integer type is signless. */
int mlirIntegerTypeIsSignless(MlirType type);
/** Checks whether the given integer type is signed. */
int mlirIntegerTypeIsSigned(MlirType type);
/** Checks whether the given integer type is unsigned. */
int mlirIntegerTypeIsUnsigned(MlirType type);
/*============================================================================*/
/* Index type. */
/*============================================================================*/
/** Checks whether the given type is an index type. */
int mlirTypeIsAIndex(MlirType type);
/** Creates an index type in the given context. The type is owned by the
* context. */
MlirType mlirIndexTypeGet(MlirContext ctx);
/*============================================================================*/
/* Floating-point types. */
/*============================================================================*/
/** Checks whether the given type is a bf16 type. */
int mlirTypeIsABF16(MlirType type);
/** Creates a bf16 type in the given context. The type is owned by the
* context. */
MlirType mlirBF16TypeGet(MlirContext ctx);
/** Checks whether the given type is an f16 type. */
int mlirTypeIsAF16(MlirType type);
/** Creates an f16 type in the given context. The type is owned by the
* context. */
MlirType mlirF16TypeGet(MlirContext ctx);
/** Checks whether the given type is an f32 type. */
int mlirTypeIsAF32(MlirType type);
/** Creates an f32 type in the given context. The type is owned by the
* context. */
MlirType mlirF32TypeGet(MlirContext ctx);
/** Checks whether the given type is an f64 type. */
int mlirTypeIsAF64(MlirType type);
/** Creates a f64 type in the given context. The type is owned by the
* context. */
MlirType mlirF64TypeGet(MlirContext ctx);
/*============================================================================*/
/* None type. */
/*============================================================================*/
/** Checks whether the given type is a None type. */
int mlirTypeIsANone(MlirType type);
/** Creates a None type in the given context. The type is owned by the
* context. */
MlirType mlirNoneTypeGet(MlirContext ctx);
/*============================================================================*/
/* Complex type. */
/*============================================================================*/
/** Checks whether the given type is a Complex type. */
int mlirTypeIsAComplex(MlirType type);
/** Creates a complex type with the given element type in the same context as
* the element type. The type is owned by the context. */
MlirType mlirComplexTypeGet(MlirType elementType);
/** Returns the element type of the given complex type. */
MlirType mlirComplexTypeGetElementType(MlirType type);
/*============================================================================*/
/* Shaped type. */
/*============================================================================*/
/** Checks whether the given type is a Shaped type. */
int mlirTypeIsAShaped(MlirType type);
/** Returns the element type of the shaped type. */
MlirType mlirShapedTypeGetElementType(MlirType type);
/** Checks whether the given shaped type is ranked. */
int mlirShapedTypeHasRank(MlirType type);
/** Returns the rank of the given ranked shaped type. */
int64_t mlirShapedTypeGetRank(MlirType type);
/** Checks whether the given shaped type has a static shape. */
int mlirShapedTypeHasStaticShape(MlirType type);
/** Checks wither the dim-th dimension of the given shaped type is dynamic. */
int mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim);
/** Returns the dim-th dimension of the given ranked shaped type. */
int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim);
/** Checks whether the given value is used as a placeholder for dynamic sizes
* in shaped types. */
int mlirShapedTypeIsDynamicSize(int64_t size);
/** Checks whether the given value is used as a placeholder for dynamic strides
* and offsets in shaped types. */
int mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);
/*============================================================================*/
/* Vector type. */
/*============================================================================*/
/** Checks whether the given type is a Vector type. */
int mlirTypeIsAVector(MlirType type);
/** Creates a vector type of the shape identified by its rank and dimensios,
* with the given element type in the same context as the element type. The type
* is owned by the context. */
MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape, MlirType elementType);
/*============================================================================*/
/* Ranked / Unranked Tensor type. */
/*============================================================================*/
/** Checks whether the given type is a Tensor type. */
int mlirTypeIsATensor(MlirType type);
/** Checks whether the given type is a ranked tensor type. */
int mlirTypeIsARankedTensor(MlirType type);
/** Checks whether the given type is an unranked tensor type. */
int mlirTypeIsAUnrankedTensor(MlirType type);
/** Creates a tensor type of a fixed rank with the given shape and element type
* in the same context as the element type. The type is owned by the context. */
MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape,
MlirType elementType);
/** Creates an unranked tensor type with the given element type in the same
* context as the element type. The type is owned by the context. */
MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
/*============================================================================*/
/* Ranked / Unranked MemRef type. */
/*============================================================================*/
/** Checks whether the given type is a MemRef type. */
int mlirTypeIsAMemRef(MlirType type);
/** Checks whether the given type is an UnrankedMemRef type. */
int mlirTypeIsAUnrankedMemRef(MlirType type);
/** Creates a MemRef type with the given rank and shape, a potentially empty
* list of affine layout maps, the given memory space and element type, in the
* same context as element type. The type is owned by the context. */
MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape,
intptr_t numMaps, MlirAttribute *affineMaps,
unsigned memorySpace);
/** Creates a MemRef type with the given rank, shape, memory space and element
* type in the same context as the element type. The type has no affine maps,
* i.e. represents a default row-major contiguous memref. The type is owned by
* the context. */
MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
int64_t *shape, unsigned memorySpace);
/** Creates an Unranked MemRef type with the given element type and in the given
* memory space. The type is owned by the context of element type. */
MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace);
/** Returns the number of affine layout maps in the given MemRef type. */
intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);
/** Returns the pos-th affine map of the given MemRef type. */
MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos);
/** Returns the memory space of the given MemRef type. */
unsigned mlirMemRefTypeGetMemorySpace(MlirType type);
/** Returns the memory spcae of the given Unranked MemRef type. */
unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type);
/*============================================================================*/
/* Tuple type. */
/*============================================================================*/
/** Checks whether the given type is a tuple type. */
int mlirTypeIsATuple(MlirType type);
/** Creates a tuple type that consists of the given list of elemental types. The
* type is owned by the context. */
MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
MlirType *elements);
/** Returns the number of types contained in a tuple. */
intptr_t mlirTupleTypeGetNumTypes(MlirType type);
/** Returns the pos-th type in the tuple type. */
MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
#ifdef __cplusplus
}
#endif
#endif // MLIR_C_STANDARDTYPES_H

View File

@ -0,0 +1,24 @@
//===- AffineMap.h - C API Utils for Affine Maps ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains declarations of implementation details of the C API for
// MLIR Affine maps. This file should not be included from C++ code other than
// C API implementation nor from C code.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CAPI_AFFINEMAP_H
#define MLIR_CAPI_AFFINEMAP_H
#include "mlir-c/AffineMap.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/AffineMap.h"
DEFINE_C_API_METHODS(MlirAffineMap, mlir::AffineMap)
#endif // MLIR_CAPI_AFFINEMAP_H

View File

@ -0,0 +1,34 @@
//===- IR.h - C API Utils for Core MLIR classes -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains declarations of implementation details of the C API for
// core MLIR classes. This file should not be included from C++ code other than
// C API implementation nor from C code.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INCLUDE_MLIR_CAPI_IR_H
#define MLIR_INCLUDE_MLIR_CAPI_IR_H
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)
DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
DEFINE_C_API_METHODS(MlirType, mlir::Type)
DEFINE_C_API_METHODS(MlirValue, mlir::Value)
DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
#endif // MLIR_INCLUDE_MLIR_CAPI_IR_H

View File

@ -0,0 +1,56 @@
//===- Wrap.h - C API Utilities ---------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains common definitions for wrapping opaque C++ pointers into
// C structures for the purpose of C API. This file should not be included from
// C++ code other than C API implementation nor from C code.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CAPI_WRAP_H
#define MLIR_CAPI_WRAP_H
#include "mlir-c/IR.h"
#include "mlir/Support/LLVM.h"
/* ========================================================================== */
/* Definitions of methods for non-owning structures used in C API. */
/* ========================================================================== */
#define DEFINE_C_API_PTR_METHODS(name, cpptype) \
static inline name wrap(cpptype *cpp) { return name{cpp}; } \
static inline cpptype *unwrap(name c) { \
return static_cast<cpptype *>(c.ptr); \
}
#define DEFINE_C_API_METHODS(name, cpptype) \
static inline name wrap(cpptype cpp) { \
return name{cpp.getAsOpaquePointer()}; \
} \
static inline cpptype unwrap(name c) { \
return cpptype::getFromOpaquePointer(c.ptr); \
}
template <typename CppTy, typename CTy>
static llvm::ArrayRef<CppTy> unwrapList(size_t size, CTy *first,
llvm::SmallVectorImpl<CppTy> &storage) {
static_assert(
std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
"incompatible C and C++ types");
if (size == 0)
return llvm::None;
assert(storage.empty() && "expected to populate storage");
storage.reserve(size);
for (size_t i = 0; i < size; ++i)
storage.push_back(unwrap(*(first + i)));
return storage;
}
#endif // MLIR_CAPI_WRAP_H

View File

@ -196,6 +196,14 @@ public:
friend ::llvm::hash_code hash_value(AffineMap arg);
/// Methods supporting C API.
const void *getAsOpaquePointer() const {
return static_cast<const void *>(map);
}
static AffineMap getFromOpaquePointer(const void *pointer) {
return AffineMap(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
}
private:
ImplType *map;

View File

@ -0,0 +1,15 @@
//===- AffineMap.cpp - C API for MLIR Affine Maps -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/IR/AffineMap.h"
// This is a placeholder for affine map bindings. The file is here to serve as a
// compilation unit that includes the headers.

View File

@ -1,6 +1,8 @@
# Main API.
add_mlir_library(MLIRCAPIIR
AffineMap.cpp
IR.cpp
StandardTypes.cpp
EXCLUDE_FROM_LIBMLIR

View File

@ -8,6 +8,7 @@
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
@ -17,46 +18,6 @@
using namespace mlir;
/* ========================================================================== */
/* Definitions of methods for non-owning structures used in C API. */
/* ========================================================================== */
#define DEFINE_C_API_PTR_METHODS(name, cpptype) \
static name wrap(cpptype *cpp) { return name{cpp}; } \
static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); }
DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext)
DEFINE_C_API_PTR_METHODS(MlirOperation, Operation)
DEFINE_C_API_PTR_METHODS(MlirBlock, Block)
DEFINE_C_API_PTR_METHODS(MlirRegion, Region)
#define DEFINE_C_API_METHODS(name, cpptype) \
static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; } \
static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); }
DEFINE_C_API_METHODS(MlirAttribute, Attribute)
DEFINE_C_API_METHODS(MlirLocation, Location);
DEFINE_C_API_METHODS(MlirType, Type)
DEFINE_C_API_METHODS(MlirValue, Value)
DEFINE_C_API_METHODS(MlirModule, ModuleOp)
template <typename CppTy, typename CTy>
static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
SmallVectorImpl<CppTy> &storage) {
static_assert(
std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
"incompatible C and C++ types");
if (size == 0)
return llvm::None;
assert(storage.empty() && "expected to populate storage");
storage.reserve(size);
for (intptr_t i = 0; i < size; ++i)
storage.push_back(unwrap(*(first + i)));
return storage;
}
/* ========================================================================== */
/* Printing helper. */
/* ========================================================================== */
@ -388,6 +349,8 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type) {
return wrap(mlir::parseType(type, unwrap(context)));
}
int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
CallbackOstream stream(callback, userData);
unwrap(type).print(stream);

View File

@ -0,0 +1,263 @@
//===- StandardTypes.cpp - C Interface to MLIR Standard Types -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/StandardTypes.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
/* ========================================================================== */
/* Integer types. */
/* ========================================================================== */
int mlirTypeIsAInteger(MlirType type) {
return unwrap(type).isa<IntegerType>();
}
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
return wrap(IntegerType::get(bitwidth, unwrap(ctx)));
}
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx)));
}
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx)));
}
unsigned mlirIntegerTypeGetWidth(MlirType type) {
return unwrap(type).cast<IntegerType>().getWidth();
}
int mlirIntegerTypeIsSignless(MlirType type) {
return unwrap(type).cast<IntegerType>().isSignless();
}
int mlirIntegerTypeIsSigned(MlirType type) {
return unwrap(type).cast<IntegerType>().isSigned();
}
int mlirIntegerTypeIsUnsigned(MlirType type) {
return unwrap(type).cast<IntegerType>().isUnsigned();
}
/* ========================================================================== */
/* Index type. */
/* ========================================================================== */
int mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa<IndexType>(); }
MlirType mlirIndexTypeGet(MlirContext ctx) {
return wrap(IndexType::get(unwrap(ctx)));
}
/* ========================================================================== */
/* Floating-point types. */
/* ========================================================================== */
int mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
MlirType mlirBF16TypeGet(MlirContext ctx) {
return wrap(FloatType::getBF16(unwrap(ctx)));
}
int mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
MlirType mlirF16TypeGet(MlirContext ctx) {
return wrap(FloatType::getF16(unwrap(ctx)));
}
int mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
MlirType mlirF32TypeGet(MlirContext ctx) {
return wrap(FloatType::getF32(unwrap(ctx)));
}
int mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
MlirType mlirF64TypeGet(MlirContext ctx) {
return wrap(FloatType::getF64(unwrap(ctx)));
}
/* ========================================================================== */
/* None type. */
/* ========================================================================== */
int mlirTypeIsANone(MlirType type) { return unwrap(type).isa<NoneType>(); }
MlirType mlirNoneTypeGet(MlirContext ctx) {
return wrap(NoneType::get(unwrap(ctx)));
}
/* ========================================================================== */
/* Complex type. */
/* ========================================================================== */
int mlirTypeIsAComplex(MlirType type) {
return unwrap(type).isa<ComplexType>();
}
MlirType mlirComplexTypeGet(MlirType elementType) {
return wrap(ComplexType::get(unwrap(elementType)));
}
MlirType mlirComplexTypeGetElementType(MlirType type) {
return wrap(unwrap(type).cast<ComplexType>().getElementType());
}
/* ========================================================================== */
/* Shaped type. */
/* ========================================================================== */
int mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa<ShapedType>(); }
MlirType mlirShapedTypeGetElementType(MlirType type) {
return wrap(unwrap(type).cast<ShapedType>().getElementType());
}
int mlirShapedTypeHasRank(MlirType type) {
return unwrap(type).cast<ShapedType>().hasRank();
}
int64_t mlirShapedTypeGetRank(MlirType type) {
return unwrap(type).cast<ShapedType>().getRank();
}
int mlirShapedTypeHasStaticShape(MlirType type) {
return unwrap(type).cast<ShapedType>().hasStaticShape();
}
int mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
return unwrap(type).cast<ShapedType>().isDynamicDim(
static_cast<unsigned>(dim));
}
int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
}
int mlirShapedTypeIsDynamicSize(int64_t size) {
return ShapedType::isDynamic(size);
}
int mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
return ShapedType::isDynamicStrideOrOffset(val);
}
/* ========================================================================== */
/* Vector type. */
/* ========================================================================== */
int mlirTypeIsAVector(MlirType type) { return unwrap(type).isa<VectorType>(); }
MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape,
MlirType elementType) {
return wrap(
VectorType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType)));
}
/* ========================================================================== */
/* Ranked / Unranked tensor type. */
/* ========================================================================== */
int mlirTypeIsATensor(MlirType type) { return unwrap(type).isa<TensorType>(); }
int mlirTypeIsARankedTensor(MlirType type) {
return unwrap(type).isa<RankedTensorType>();
}
int mlirTypeIsAUnrankedTensor(MlirType type) {
return unwrap(type).isa<UnrankedTensorType>();
}
MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape,
MlirType elementType) {
return wrap(RankedTensorType::get(
llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType)));
}
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
return wrap(UnrankedTensorType::get(unwrap(elementType)));
}
/* ========================================================================== */
/* Ranked / Unranked MemRef type. */
/* ========================================================================== */
int mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape,
intptr_t numMaps, MlirAffineMap *affineMaps,
unsigned memorySpace) {
SmallVector<AffineMap, 1> maps;
(void)unwrapList(numMaps, affineMaps, maps);
return wrap(
MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType), maps, memorySpace));
}
MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
int64_t *shape, unsigned memorySpace) {
return wrap(
MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType), llvm::None, memorySpace));
}
intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
return static_cast<intptr_t>(
unwrap(type).cast<MemRefType>().getAffineMaps().size());
}
MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
return wrap(unwrap(type).cast<MemRefType>().getAffineMaps()[pos]);
}
unsigned mlirMemRefTypeGetMemorySpace(MlirType type) {
return unwrap(type).cast<MemRefType>().getMemorySpace();
}
int mlirTypeIsAUnrankedMemRef(MlirType type) {
return unwrap(type).isa<UnrankedMemRefType>();
}
MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
}
unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
return unwrap(type).cast<UnrankedMemRefType>().getMemorySpace();
}
/* ========================================================================== */
/* Tuple type. */
/* ========================================================================== */
int mlirTypeIsATuple(MlirType type) { return unwrap(type).isa<TupleType>(); }
MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
MlirType *elements) {
SmallVector<Type, 4> types;
ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
return wrap(TupleType::get(typeRef, unwrap(ctx)));
}
intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
return unwrap(type).cast<TupleType>().size();
}
MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
}

View File

@ -12,6 +12,7 @@
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "mlir-c/StandardTypes.h"
#include <assert.h>
#include <stdio.h>
@ -240,6 +241,145 @@ static void printFirstOfEach(MlirOperation operation) {
fprintf(stderr, "\n");
}
/// Dumps instances of all standard types to check that C API works correctly.
/// Additionally, performs simple identity checks that a standard type
/// constructed with C API can be inspected and has the expected type. The
/// latter achieves full coverage of C API for standard types. Returns 0 on
/// success and a non-zero error code on failure.
static int printStandardTypes(MlirContext ctx) {
// Integer types.
MlirType i32 = mlirIntegerTypeGet(ctx, 32);
MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
return 1;
if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
return 2;
if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
return 3;
if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
return 4;
if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32))
return 5;
mlirTypeDump(i32);
fprintf(stderr, "\n");
mlirTypeDump(si32);
fprintf(stderr, "\n");
mlirTypeDump(ui32);
fprintf(stderr, "\n");
// Index type.
MlirType index = mlirIndexTypeGet(ctx);
if (!mlirTypeIsAIndex(index))
return 6;
mlirTypeDump(index);
fprintf(stderr, "\n");
// Floating-point types.
MlirType bf16 = mlirBF16TypeGet(ctx);
MlirType f16 = mlirF16TypeGet(ctx);
MlirType f32 = mlirF32TypeGet(ctx);
MlirType f64 = mlirF64TypeGet(ctx);
if (!mlirTypeIsABF16(bf16))
return 7;
if (!mlirTypeIsAF16(f16))
return 9;
if (!mlirTypeIsAF32(f32))
return 10;
if (!mlirTypeIsAF64(f64))
return 11;
mlirTypeDump(bf16);
fprintf(stderr, "\n");
mlirTypeDump(f16);
fprintf(stderr, "\n");
mlirTypeDump(f32);
fprintf(stderr, "\n");
mlirTypeDump(f64);
fprintf(stderr, "\n");
// None type.
MlirType none = mlirNoneTypeGet(ctx);
if (!mlirTypeIsANone(none))
return 12;
mlirTypeDump(none);
fprintf(stderr, "\n");
// Complex type.
MlirType cplx = mlirComplexTypeGet(f32);
if (!mlirTypeIsAComplex(cplx) ||
!mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
return 13;
mlirTypeDump(cplx);
fprintf(stderr, "\n");
// Vector (and Shaped) type. ShapedType is a common base class for vectors,
// memrefs and tensors, one cannot create instances of this class so it is
// tested on an instance of vector type.
int64_t shape[] = {2, 3};
MlirType vector =
mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
return 14;
if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
!mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
mlirShapedTypeGetDimSize(vector, 0) != 2 ||
mlirShapedTypeIsDynamicDim(vector, 0) ||
mlirShapedTypeGetDimSize(vector, 1) != 3 ||
!mlirShapedTypeHasStaticShape(vector))
return 15;
mlirTypeDump(vector);
fprintf(stderr, "\n");
// Ranked tensor type.
MlirType rankedTensor =
mlirRankedTensorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
if (!mlirTypeIsATensor(rankedTensor) ||
!mlirTypeIsARankedTensor(rankedTensor))
return 16;
mlirTypeDump(rankedTensor);
fprintf(stderr, "\n");
// Unranked tensor type.
MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
if (!mlirTypeIsATensor(unrankedTensor) ||
!mlirTypeIsAUnrankedTensor(unrankedTensor) ||
mlirShapedTypeHasRank(unrankedTensor))
return 17;
mlirTypeDump(unrankedTensor);
fprintf(stderr, "\n");
// MemRef type.
MlirType memRef = mlirMemRefTypeContiguousGet(
f32, sizeof(shape) / sizeof(int64_t), shape, 2);
if (!mlirTypeIsAMemRef(memRef) ||
mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
mlirMemRefTypeGetMemorySpace(memRef) != 2)
return 18;
mlirTypeDump(memRef);
fprintf(stderr, "\n");
// Unranked MemRef type.
MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, 4);
if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
mlirTypeIsAMemRef(unrankedMemRef) ||
mlirUnrankedMemrefGetMemorySpace(unrankedMemRef) != 4)
return 19;
mlirTypeDump(unrankedMemRef);
fprintf(stderr, "\n");
// Tuple type.
MlirType types[] = {unrankedMemRef, f32};
MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
return 20;
mlirTypeDump(tuple);
fprintf(stderr, "\n");
return 0;
}
int main() {
mlirRegisterAllDialects();
MlirContext ctx = mlirContextCreate();
@ -293,6 +433,31 @@ int main() {
// clang-format on
mlirModuleDestroy(moduleOp);
// clang-format off
// CHECK-LABEL: @types
// CHECK: i32
// CHECK: si32
// CHECK: ui32
// CHECK: index
// CHECK: bf16
// CHECK: f16
// CHECK: f32
// CHECK: f64
// CHECK: none
// CHECK: complex<f32>
// CHECK: vector<2x3xf32>
// CHECK: tensor<2x3xf32>
// CHECK: tensor<*xf32>
// CHECK: memref<2x3xf32, 2>
// CHECK: memref<*xf32, 4>
// CHECK: tuple<memref<*xf32, 4>, f32>
// CHECK: 0
// clang-format on
fprintf(stderr, "@types");
int errcode = printStandardTypes(ctx);
fprintf(stderr, "%d\n", errcode);
mlirContextDestroy(ctx);
return 0;