[mlir] factor out common parts of the converstion to the LLVM dialect

"Standard-to-LLVM" conversion is one of the oldest passes in existence. It has
become quite large due to the size of the Standard dialect itself, which is
being split into multiple smaller dialects. Furthermore, several conversion
features are useful for any dialect that is being converted to the LLVM
dialect, which, without this refactoring, creates a dependency from those
conversions to the "standard-to-llvm" one.

Put several of the reusable utilities from this conversion to a separate
library, namely:
- type converter from builtin to LLVM dialect types;
- utility for building and accessing values of LLVM structure type;
- utility for building and accessing values that represent memref in the LLVM
  dialect;
- lowering options applicable everywhere.

Additionally, remove the type wrapping/unwrapping notion from the type
converter that is no longer relevant since LLVM types has been reimplemented as
first-class MLIR types.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D105534
This commit is contained in:
Alex Zinenko 2021-07-07 09:46:27 +02:00
parent 0c4e538d8f
commit b5d847b1b9
28 changed files with 1803 additions and 1581 deletions

View File

@ -8,15 +8,34 @@
#ifndef MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_ #ifndef MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
#define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_ #define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/LLVMCommon/StructBuilder.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
class MLIRContext; class LLVMTypeConverter;
class ModuleOp; class ModuleOp;
template <typename T> template <typename T>
class OperationPass; class OperationPass;
class ComplexStructBuilder : public StructBuilder {
public:
/// Construct a helper for the given complex number value.
using StructBuilder::StructBuilder;
/// Build IR creating an `undef` value of the complex number type.
static ComplexStructBuilder undef(OpBuilder &builder, Location loc,
Type type);
// Build IR extracting the real value from the complex number struct.
Value real(OpBuilder &builder, Location loc);
// Build IR inserting the real value into the complex number struct.
void setReal(OpBuilder &builder, Location loc, Value real);
// Build IR extracting the imaginary value from the complex number struct.
Value imaginary(OpBuilder &builder, Location loc);
// Build IR inserting the imaginary value into the complex number struct.
void setImaginary(OpBuilder &builder, Location loc, Value imaginary);
};
/// Populate the given list with patterns that convert from Complex to LLVM. /// Populate the given list with patterns that convert from Complex to LLVM.
void populateComplexToLLVMConversionPatterns(LLVMTypeConverter &converter, void populateComplexToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns); RewritePatternSet &patterns);

View File

@ -8,7 +8,7 @@
#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ #ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ #define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include <memory> #include <memory>
namespace mlir { namespace mlir {

View File

@ -8,7 +8,7 @@
#ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ #ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
#define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ #define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include <memory> #include <memory>
namespace mlir { namespace mlir {

View File

@ -0,0 +1,73 @@
//===- LoweringOptions.h - Common config for lowering to LLVM ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides a configuration shared by several conversions targeting the LLVM
// dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H
#define MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H
#include "llvm/IR/DataLayout.h"
namespace mlir {
class DataLayout;
class MLIRContext;
/// Value to pass as bitwidth for the index type when the converter is expected
/// to derive the bitwidth from the LLVM data layout.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
/// Options to control the Standard dialect to LLVM lowering. The struct is used
/// to share lowering options between passes, patterns, and type converter.
class LowerToLLVMOptions {
public:
explicit LowerToLLVMOptions(MLIRContext *ctx);
LowerToLLVMOptions(MLIRContext *ctx, const DataLayout &dl);
bool useBarePtrCallConv = false;
bool emitCWrappers = false;
enum class AllocLowering {
/// Use malloc for for heap allocations.
Malloc,
/// Use aligned_alloc for heap allocations.
AlignedAlloc,
/// Do not lower heap allocations. Users must provide their own patterns for
/// AllocOp and DeallocOp lowering.
None
};
AllocLowering allocLowering = AllocLowering::Malloc;
/// The data layout of the module to produce. This must be consistent with the
/// data layout used in the upper levels of the lowering pipeline.
// TODO: this should be replaced by MLIR data layout when one exists.
llvm::DataLayout dataLayout = llvm::DataLayout("");
/// Set the index bitwidth to the given value.
void overrideIndexBitwidth(unsigned bitwidth) {
assert(bitwidth != kDeriveIndexBitwidthFromDataLayout &&
"can only override to a concrete bitwidth");
indexBitwidth = bitwidth;
}
/// Get the index bitwidth.
unsigned getIndexBitwidth() const { return indexBitwidth; }
private:
unsigned indexBitwidth;
};
} // namespace mlir
#endif // MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H

View File

@ -0,0 +1,245 @@
//===- MemRefBuilder.h - Helper for LLVM MemRef equivalents -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides a convenience API for emitting IR that inspects or constructs values
// of LLVM dialect structure type that correspond to ranked or unranked memref.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H
#define MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H
#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
#include "mlir/IR/OperationSupport.h"
namespace mlir {
class LLVMTypeConverter;
class MemRefType;
class UnrankedMemRefType;
namespace LLVM {
class LLVMPointerType;
} // namespace LLVM
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
explicit MemRefDescriptor(Value descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, Value memory);
/// Builds IR extracting the allocated pointer from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc);
/// Builds IR inserting the allocated pointer into the descriptor.
void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);
/// Builds IR extracting the aligned pointer from the descriptor.
Value alignedPtr(OpBuilder &builder, Location loc);
/// Builds IR inserting the aligned pointer into the descriptor.
void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);
/// Builds IR extracting the offset from the descriptor.
Value offset(OpBuilder &builder, Location loc);
/// Builds IR inserting the offset into the descriptor.
void setOffset(OpBuilder &builder, Location loc, Value offset);
void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);
/// Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos);
Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);
/// Builds IR inserting the pos-th size into the descriptor
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
uint64_t size);
/// Builds IR extracting the pos-th size from the descriptor.
Value stride(OpBuilder &builder, Location loc, unsigned pos);
/// Builds IR inserting the pos-th stride into the descriptor
void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride);
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
uint64_t stride);
/// Returns the (LLVM) pointer type this descriptor contains.
LLVM::LLVMPointerType getElementPtrType();
/// Builds IR populating a MemRef descriptor structure from a list of
/// individual values composing that descriptor, in the following order:
/// - allocated pointer;
/// - aligned pointer;
/// - offset;
/// - <rank> sizes;
/// - <rank> shapes;
/// where <rank> is the MemRef rank as provided in `type`.
static Value pack(OpBuilder &builder, Location loc,
LLVMTypeConverter &converter, MemRefType type,
ValueRange values);
/// Builds IR extracting individual elements of a MemRef descriptor structure
/// and returning them as `results` list.
static void unpack(OpBuilder &builder, Location loc, Value packed,
MemRefType type, SmallVectorImpl<Value> &results);
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
static unsigned getNumUnpackedValues(MemRefType type);
private:
// Cached index type.
Type indexType;
};
/// Helper class allowing the user to access a range of Values that correspond
/// to an unpacked memref descriptor using named accessors. This does not own
/// the values.
class MemRefDescriptorView {
public:
/// Constructs the view from a range of values. Infers the rank from the size
/// of the range.
explicit MemRefDescriptorView(ValueRange range);
/// Returns the allocated pointer Value.
Value allocatedPtr();
/// Returns the aligned pointer Value.
Value alignedPtr();
/// Returns the offset Value.
Value offset();
/// Returns the pos-th size Value.
Value size(unsigned pos);
/// Returns the pos-th stride Value.
Value stride(unsigned pos);
private:
/// Rank of the memref the descriptor is pointing to.
int rank;
/// Underlying range of Values.
ValueRange elements;
};
class UnrankedMemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
explicit UnrankedMemRefDescriptor(Value descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR extracting the rank from the descriptor
Value rank(OpBuilder &builder, Location loc);
/// Builds IR setting the rank in the descriptor
void setRank(OpBuilder &builder, Location loc, Value value);
/// Builds IR extracting ranked memref descriptor ptr
Value memRefDescPtr(OpBuilder &builder, Location loc);
/// Builds IR setting ranked memref descriptor ptr
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
/// Builds IR populating an unranked MemRef descriptor structure from a list
/// of individual constituent values in the following order:
/// - rank of the memref;
/// - pointer to the memref descriptor.
static Value pack(OpBuilder &builder, Location loc,
LLVMTypeConverter &converter, UnrankedMemRefType type,
ValueRange values);
/// Builds IR extracting individual elements that compose an unranked memref
/// descriptor and returns them as `results` list.
static void unpack(OpBuilder &builder, Location loc, Value packed,
SmallVectorImpl<Value> &results);
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
static unsigned getNumUnpackedValues() { return 2; }
/// Builds IR computing the sizes in bytes (suitable for opaque allocation)
/// and appends the corresponding values into `sizes`.
static void computeSizes(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
ArrayRef<UnrankedMemRefDescriptor> values,
SmallVectorImpl<Value> &sizes);
/// TODO: The following accessors don't take alignment rules between elements
/// of the descriptor struct into account. For some architectures, it might be
/// necessary to extend them and to use `llvm::DataLayout` contained in
/// `LLVMTypeConverter`.
/// Builds IR extracting the allocated pointer from the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc,
Value memRefDescPtr, Type elemPtrPtrType);
/// Builds IR inserting the allocated pointer into the descriptor.
static void setAllocatedPtr(OpBuilder &builder, Location loc,
Value memRefDescPtr, Type elemPtrPtrType,
Value allocatedPtr);
/// Builds IR extracting the aligned pointer from the descriptor.
static Value alignedPtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter, Value memRefDescPtr,
Type elemPtrPtrType);
/// Builds IR inserting the aligned pointer into the descriptor.
static void setAlignedPtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr, Type elemPtrPtrType,
Value alignedPtr);
/// Builds IR extracting the offset from the descriptor.
static Value offset(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter, Value memRefDescPtr,
Type elemPtrPtrType);
/// Builds IR inserting the offset into the descriptor.
static void setOffset(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter, Value memRefDescPtr,
Type elemPtrPtrType, Value offset);
/// Builds IR extracting the pointer to the first element of the size array.
static Value sizeBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrPtrType);
/// Builds IR extracting the size[index] from the descriptor.
static Value size(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter, Value sizeBasePtr,
Value index);
/// Builds IR inserting the size[index] into the descriptor.
static void setSize(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter, Value sizeBasePtr,
Value index, Value size);
/// Builds IR extracting the pointer to the first element of the stride array.
static Value strideBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value rank);
/// Builds IR extracting the stride[index] from the descriptor.
static Value stride(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter, Value strideBasePtr,
Value index, Value stride);
/// Builds IR inserting the stride[index] into the descriptor.
static void setStride(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter, Value strideBasePtr,
Value index, Value stride);
};
} // namespace mlir
#endif // MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H_

View File

@ -0,0 +1,51 @@
//===- StructBuilder.h - Helper for building LLVM structs -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides a convenience API for emitting IR that inspects or constructs values
// of LLVM dialect structure types.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_LLVMCOMMON_STRUCTBUILDER_H
#define MLIR_CONVERSION_LLVMCOMMON_STRUCTBUILDER_H
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace mlir {
class OpBuilder;
/// Helper class to produce LLVM dialect operations extracting or inserting
/// values to a struct.
class StructBuilder {
public:
/// Construct a helper for the given value.
explicit StructBuilder(Value v);
/// Builds IR creating an `undef` value of the descriptor type.
static StructBuilder undef(OpBuilder &builder, Location loc,
Type descriptorType);
/*implicit*/ operator Value() { return value; }
protected:
// LLVM value
Value value;
// Cached struct type.
Type structType;
protected:
/// Builds IR to extract a value from the struct at position pos
Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
/// Builds IR to set a value in the struct at position pos
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
};
} // namespace mlir
#endif // MLIR_CONVERSION_LLVMCOMMON_STRUCTBUILDER_H

View File

@ -0,0 +1,227 @@
//===- TypeConverter.h - Convert builtin to LLVM dialect types --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides a type converter configuration for converting most builtin types to
// LLVM dialect types.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
#define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class DataLayoutAnalysis;
class LowerToLLVMOptions;
namespace LLVM {
class LLVMDialect;
} // namespace LLVM
/// Conversion from types in the Standard dialect to the LLVM IR dialect.
class LLVMTypeConverter : public TypeConverter {
/// Give structFuncArgTypeConverter access to memref-specific functions.
friend LogicalResult
structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
SmallVectorImpl<Type> &result);
public:
using TypeConverter::convertType;
/// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
/// Optionally takes a data layout analysis to use in conversions.
LLVMTypeConverter(MLIRContext *ctx,
const DataLayoutAnalysis *analysis = nullptr);
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions. Optionally
/// takes a data layout analysis to use in conversions.
LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options,
const DataLayoutAnalysis *analysis = nullptr);
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
SignatureConversion &result);
/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one value is
/// returned, create an LLVM IR structure type with elements that correspond
/// to each of the MLIR types converted with `convertType`.
Type packFunctionResults(TypeRange types);
/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
Type convertCallingConventionType(Type type);
/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
Location loc, ArrayRef<Type> stdTypes,
SmallVectorImpl<Value> &values);
/// Returns the MLIR context.
MLIRContext &getContext();
/// Returns the LLVM dialect.
LLVM::LLVMDialect *getDialect() { return llvmDialect; }
const LowerToLLVMOptions &getOptions() const { return options; }
/// Promote the LLVM representation of all operands including promoting MemRef
/// descriptors to stack and use pointers to struct to avoid the complexity
/// of the platform-specific C/C++ ABI lowering related to struct argument
/// passing.
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
ValueRange operands,
OpBuilder &builder);
/// Promote the LLVM struct representation of one MemRef descriptor to stack
/// and use pointer to struct to avoid the complexity of the platform-specific
/// C/C++ ABI lowering related to struct argument passing.
Value promoteOneMemRefDescriptor(Location loc, Value operand,
OpBuilder &builder);
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments. Also converts the return
/// type to a pointer argument if it is a struct. Returns true if this
/// was the case.
std::pair<Type, bool> convertFunctionTypeCWrapper(FunctionType type);
/// Returns the data layout to use during and after conversion.
const llvm::DataLayout &getDataLayout() { return options.dataLayout; }
/// Returns the data layout analysis to query during conversion.
const DataLayoutAnalysis *getDataLayoutAnalysis() const {
return dataLayoutAnalysis;
}
/// Gets the LLVM representation of the index type. The returned type is an
/// integer type with the size configured for this type converter.
Type getIndexType();
/// Gets the bitwidth of the index type when converted to LLVM.
unsigned getIndexTypeBitwidth() { return options.getIndexBitwidth(); }
/// Gets the pointer bitwidth.
unsigned getPointerBitwidth(unsigned addressSpace = 0);
/// Returns the size of the memref descriptor object in bytes.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout);
/// Returns the size of the unranked memref descriptor object in bytes.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
const DataLayout &layout);
protected:
/// Pointer to the LLVM dialect.
LLVM::LLVMDialect *llvmDialect;
private:
/// Convert a function type. The arguments and results are converted one by
/// one. Additionally, if the function returns more than one value, pack the
/// results into an LLVM IR structure type so that the converted function type
/// returns at most one result.
Type convertFunctionType(FunctionType type);
/// Convert the index type. Uses llvmModule data layout to create an integer
/// of the pointer bitwidth.
Type convertIndexType(IndexType type);
/// Convert an integer type `i*` to `!llvm<"i*">`.
Type convertIntegerType(IntegerType type);
/// Convert a floating point type: `f16` to `f16`, `f32` to
/// `f32` and `f64` to `f64`. `bf16` is not supported
/// by LLVM.
Type convertFloatType(FloatType type);
/// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
/// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
/// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
Type convertComplexType(ComplexType type);
/// Convert a memref type into an LLVM type that captures the relevant data.
Type convertMemRefType(MemRefType type);
/// Convert a memref type into a list of LLVM IR types that will form the
/// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
/// arrays in the descriptors are unpacked to individual index-typed elements,
/// else they are are kept as rank-sized arrays of index type. In particular,
/// the list will contain:
/// - two pointers to the memref element type, followed by
/// - an index-typed offset, followed by
/// - (if unpackAggregates = true)
/// - one index-typed size per dimension of the memref, followed by
/// - one index-typed stride per dimension of the memref.
/// - (if unpackArrregates = false)
/// - one rank-sized array of index-type for the size of each dimension
/// - one rank-sized array of index-type for the stride of each dimension
///
/// For example, memref<?x?xf32> is converted to the following list:
/// - `!llvm<"float*">` (allocated pointer),
/// - `!llvm<"float*">` (aligned pointer),
/// - `i64` (offset),
/// - `i64`, `i64` (sizes),
/// - `i64`, `i64` (strides).
/// These types can be recomposed to a memref descriptor struct.
SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
bool unpackAggregates);
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
/// that will form the unranked memref descriptor. In particular, this list
/// contains:
/// - an integer rank, followed by
/// - a pointer to the memref descriptor struct.
/// For example, memref<*xf32> is converted to the following list:
/// i64 (rank)
/// !llvm<"i8*"> (type-erased pointer).
/// These types can be recomposed to a unranked memref descriptor struct.
SmallVector<Type, 2> getUnrankedMemRefDescriptorFields();
// Convert an unranked memref type to an LLVM type that captures the
// runtime rank and a pointer to the static ranked memref desc
Type convertUnrankedMemRefType(UnrankedMemRefType type);
/// Convert a memref type to a bare pointer to the memref element type.
Type convertMemRefToBarePtr(BaseMemRefType type);
/// Convert a 1D vector type into an LLVM vector type.
Type convertVectorType(VectorType type);
/// Options for customizing the llvm lowering.
LowerToLLVMOptions options;
/// Data layout analysis mapping scopes to layouts active in them.
const DataLayoutAnalysis *dataLayoutAnalysis;
};
/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result);
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result);
} // namespace mlir
#endif // MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H

View File

@ -15,6 +15,8 @@
#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
@ -38,458 +40,7 @@ class LLVMDialect;
class LLVMPointerType; class LLVMPointerType;
} // namespace LLVM } // namespace LLVM
/// Callback to convert function argument types. It converts a MemRef function // ------------------
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result);
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result);
/// Conversion from types in the Standard dialect to the LLVM IR dialect.
class LLVMTypeConverter : public TypeConverter {
/// Give structFuncArgTypeConverter access to memref-specific functions.
friend LogicalResult
structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
SmallVectorImpl<Type> &result);
public:
using TypeConverter::convertType;
/// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
/// Optionally takes a data layout analysis to use in conversions.
LLVMTypeConverter(MLIRContext *ctx,
const DataLayoutAnalysis *analysis = nullptr);
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions. Optionally
/// takes a data layout analysis to use in conversions.
LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options,
const DataLayoutAnalysis *analysis = nullptr);
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
SignatureConversion &result);
/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one value is
/// returned, create an LLVM IR structure type with elements that correspond
/// to each of the MLIR types converted with `convertType`.
Type packFunctionResults(TypeRange types);
/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
Type convertCallingConventionType(Type type);
/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
Location loc, ArrayRef<Type> stdTypes,
SmallVectorImpl<Value> &values);
/// Returns the MLIR context.
MLIRContext &getContext();
/// Returns the LLVM dialect.
LLVM::LLVMDialect *getDialect() { return llvmDialect; }
const LowerToLLVMOptions &getOptions() const { return options; }
/// Promote the LLVM representation of all operands including promoting MemRef
/// descriptors to stack and use pointers to struct to avoid the complexity
/// of the platform-specific C/C++ ABI lowering related to struct argument
/// passing.
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
ValueRange operands,
OpBuilder &builder);
/// Promote the LLVM struct representation of one MemRef descriptor to stack
/// and use pointer to struct to avoid the complexity of the platform-specific
/// C/C++ ABI lowering related to struct argument passing.
Value promoteOneMemRefDescriptor(Location loc, Value operand,
OpBuilder &builder);
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments. Also converts the return
/// type to a pointer argument if it is a struct. Returns true if this
/// was the case.
std::pair<Type, bool> convertFunctionTypeCWrapper(FunctionType type);
/// Returns the data layout to use during and after conversion.
const llvm::DataLayout &getDataLayout() { return options.dataLayout; }
/// Returns the data layout analysis to query during conversion.
const DataLayoutAnalysis *getDataLayoutAnalysis() const {
return dataLayoutAnalysis;
}
/// Gets the LLVM representation of the index type. The returned type is an
/// integer type with the size configured for this type converter.
Type getIndexType();
/// Gets the bitwidth of the index type when converted to LLVM.
unsigned getIndexTypeBitwidth() { return options.getIndexBitwidth(); }
/// Gets the pointer bitwidth.
unsigned getPointerBitwidth(unsigned addressSpace = 0);
/// Returns the size of the memref descriptor object in bytes.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout);
/// Returns the size of the unranked memref descriptor object in bytes.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
const DataLayout &layout);
protected:
/// Pointer to the LLVM dialect.
LLVM::LLVMDialect *llvmDialect;
private:
/// Convert a function type. The arguments and results are converted one by
/// one. Additionally, if the function returns more than one value, pack the
/// results into an LLVM IR structure type so that the converted function type
/// returns at most one result.
Type convertFunctionType(FunctionType type);
/// Convert the index type. Uses llvmModule data layout to create an integer
/// of the pointer bitwidth.
Type convertIndexType(IndexType type);
/// Convert an integer type `i*` to `!llvm<"i*">`.
Type convertIntegerType(IntegerType type);
/// Convert a floating point type: `f16` to `f16`, `f32` to
/// `f32` and `f64` to `f64`. `bf16` is not supported
/// by LLVM.
Type convertFloatType(FloatType type);
/// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
/// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
/// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
Type convertComplexType(ComplexType type);
/// Convert a memref type into an LLVM type that captures the relevant data.
Type convertMemRefType(MemRefType type);
/// Convert a memref type into a list of LLVM IR types that will form the
/// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
/// arrays in the descriptors are unpacked to individual index-typed elements,
/// else they are are kept as rank-sized arrays of index type. In particular,
/// the list will contain:
/// - two pointers to the memref element type, followed by
/// - an index-typed offset, followed by
/// - (if unpackAggregates = true)
/// - one index-typed size per dimension of the memref, followed by
/// - one index-typed stride per dimension of the memref.
/// - (if unpackArrregates = false)
/// - one rank-sized array of index-type for the size of each dimension
/// - one rank-sized array of index-type for the stride of each dimension
///
/// For example, memref<?x?xf32> is converted to the following list:
/// - `!llvm<"float*">` (allocated pointer),
/// - `!llvm<"float*">` (aligned pointer),
/// - `i64` (offset),
/// - `i64`, `i64` (sizes),
/// - `i64`, `i64` (strides).
/// These types can be recomposed to a memref descriptor struct.
SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
bool unpackAggregates);
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
/// that will form the unranked memref descriptor. In particular, this list
/// contains:
/// - an integer rank, followed by
/// - a pointer to the memref descriptor struct.
/// For example, memref<*xf32> is converted to the following list:
/// i64 (rank)
/// !llvm<"i8*"> (type-erased pointer).
/// These types can be recomposed to a unranked memref descriptor struct.
SmallVector<Type, 2> getUnrankedMemRefDescriptorFields();
// Convert an unranked memref type to an LLVM type that captures the
// runtime rank and a pointer to the static ranked memref desc
Type convertUnrankedMemRefType(UnrankedMemRefType type);
/// Convert a memref type to a bare pointer to the memref element type.
Type convertMemRefToBarePtr(BaseMemRefType type);
/// Convert a 1D vector type into an LLVM vector type.
Type convertVectorType(VectorType type);
/// Options for customizing the llvm lowering.
LowerToLLVMOptions options;
/// Data layout analysis mapping scopes to layouts active in them.
const DataLayoutAnalysis *dataLayoutAnalysis;
};
/// Helper class to produce LLVM dialect operations extracting or inserting
/// values to a struct.
class StructBuilder {
public:
/// Construct a helper for the given value.
explicit StructBuilder(Value v);
/// Builds IR creating an `undef` value of the descriptor type.
static StructBuilder undef(OpBuilder &builder, Location loc,
Type descriptorType);
/*implicit*/ operator Value() { return value; }
protected:
// LLVM value
Value value;
// Cached struct type.
Type structType;
protected:
/// Builds IR to extract a value from the struct at position pos
Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
/// Builds IR to set a value in the struct at position pos
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
};
class ComplexStructBuilder : public StructBuilder {
public:
/// Construct a helper for the given complex number value.
using StructBuilder::StructBuilder;
/// Build IR creating an `undef` value of the complex number type.
static ComplexStructBuilder undef(OpBuilder &builder, Location loc,
Type type);
// Build IR extracting the real value from the complex number struct.
Value real(OpBuilder &builder, Location loc);
// Build IR inserting the real value into the complex number struct.
void setReal(OpBuilder &builder, Location loc, Value real);
// Build IR extracting the imaginary value from the complex number struct.
Value imaginary(OpBuilder &builder, Location loc);
// Build IR inserting the imaginary value into the complex number struct.
void setImaginary(OpBuilder &builder, Location loc, Value imaginary);
};
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
explicit MemRefDescriptor(Value descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, Value memory);
/// Builds IR extracting the allocated pointer from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc);
/// Builds IR inserting the allocated pointer into the descriptor.
void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);
/// Builds IR extracting the aligned pointer from the descriptor.
Value alignedPtr(OpBuilder &builder, Location loc);
/// Builds IR inserting the aligned pointer into the descriptor.
void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);
/// Builds IR extracting the offset from the descriptor.
Value offset(OpBuilder &builder, Location loc);
/// Builds IR inserting the offset into the descriptor.
void setOffset(OpBuilder &builder, Location loc, Value offset);
void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);
/// Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos);
Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);
/// Builds IR inserting the pos-th size into the descriptor
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
uint64_t size);
/// Builds IR extracting the pos-th size from the descriptor.
Value stride(OpBuilder &builder, Location loc, unsigned pos);
/// Builds IR inserting the pos-th stride into the descriptor
void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride);
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
uint64_t stride);
/// Returns the (LLVM) pointer type this descriptor contains.
LLVM::LLVMPointerType getElementPtrType();
/// Builds IR populating a MemRef descriptor structure from a list of
/// individual values composing that descriptor, in the following order:
/// - allocated pointer;
/// - aligned pointer;
/// - offset;
/// - <rank> sizes;
/// - <rank> shapes;
/// where <rank> is the MemRef rank as provided in `type`.
static Value pack(OpBuilder &builder, Location loc,
LLVMTypeConverter &converter, MemRefType type,
ValueRange values);
/// Builds IR extracting individual elements of a MemRef descriptor structure
/// and returning them as `results` list.
static void unpack(OpBuilder &builder, Location loc, Value packed,
MemRefType type, SmallVectorImpl<Value> &results);
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
static unsigned getNumUnpackedValues(MemRefType type);
private:
// Cached index type.
Type indexType;
};
/// Helper class allowing the user to access a range of Values that correspond
/// to an unpacked memref descriptor using named accessors. This does not own
/// the values.
class MemRefDescriptorView {
public:
/// Constructs the view from a range of values. Infers the rank from the size
/// of the range.
explicit MemRefDescriptorView(ValueRange range);
/// Returns the allocated pointer Value.
Value allocatedPtr();
/// Returns the aligned pointer Value.
Value alignedPtr();
/// Returns the offset Value.
Value offset();
/// Returns the pos-th size Value.
Value size(unsigned pos);
/// Returns the pos-th stride Value.
Value stride(unsigned pos);
private:
/// Rank of the memref the descriptor is pointing to.
int rank;
/// Underlying range of Values.
ValueRange elements;
};
class UnrankedMemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
explicit UnrankedMemRefDescriptor(Value descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR extracting the rank from the descriptor
Value rank(OpBuilder &builder, Location loc);
/// Builds IR setting the rank in the descriptor
void setRank(OpBuilder &builder, Location loc, Value value);
/// Builds IR extracting ranked memref descriptor ptr
Value memRefDescPtr(OpBuilder &builder, Location loc);
/// Builds IR setting ranked memref descriptor ptr
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
/// Builds IR populating an unranked MemRef descriptor structure from a list
/// of individual constituent values in the following order:
/// - rank of the memref;
/// - pointer to the memref descriptor.
static Value pack(OpBuilder &builder, Location loc,
LLVMTypeConverter &converter, UnrankedMemRefType type,
ValueRange values);
/// Builds IR extracting individual elements that compose an unranked memref
/// descriptor and returns them as `results` list.
static void unpack(OpBuilder &builder, Location loc, Value packed,
SmallVectorImpl<Value> &results);
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
static unsigned getNumUnpackedValues() { return 2; }
/// Builds IR computing the sizes in bytes (suitable for opaque allocation)
/// and appends the corresponding values into `sizes`.
static void computeSizes(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
ArrayRef<UnrankedMemRefDescriptor> values,
SmallVectorImpl<Value> &sizes);
/// TODO: The following accessors don't take alignment rules between elements
/// of the descriptor struct into account. For some architectures, it might be
/// necessary to extend them and to use `llvm::DataLayout` contained in
/// `LLVMTypeConverter`.
/// Builds IR extracting the allocated pointer from the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc,
Value memRefDescPtr, Type elemPtrPtrType);
/// Builds IR inserting the allocated pointer into the descriptor.
static void setAllocatedPtr(OpBuilder &builder, Location loc,
Value memRefDescPtr, Type elemPtrPtrType,
Value allocatedPtr);
/// Builds IR extracting the aligned pointer from the descriptor.
static Value alignedPtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter, Value memRefDescPtr,
Type elemPtrPtrType);
/// Builds IR inserting the aligned pointer into the descriptor.
static void setAlignedPtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr, Type elemPtrPtrType,
Value alignedPtr);
/// Builds IR extracting the offset from the descriptor.
static Value offset(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter, Value memRefDescPtr,
Type elemPtrPtrType);
/// Builds IR inserting the offset into the descriptor.
static void setOffset(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter, Value memRefDescPtr,
Type elemPtrPtrType, Value offset);
/// Builds IR extracting the pointer to the first element of the size array.
static Value sizeBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
LLVM::LLVMPointerType elemPtrPtrType);
/// Builds IR extracting the size[index] from the descriptor.
static Value size(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter, Value sizeBasePtr,
Value index);
/// Builds IR inserting the size[index] into the descriptor.
static void setSize(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter, Value sizeBasePtr,
Value index, Value size);
/// Builds IR extracting the pointer to the first element of the stride array.
static Value strideBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value rank);
/// Builds IR extracting the stride[index] from the descriptor.
static Value stride(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter, Value strideBasePtr,
Value index, Value stride);
/// Builds IR inserting the stride[index] into the descriptor.
static void setStride(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter, Value strideBasePtr,
Value index, Value stride);
};
/// Base class for operation conversions targeting the LLVM IR dialect. It /// Base class for operation conversions targeting the LLVM IR dialect. It
/// provides the conversion patterns with access to the LLVMTypeConverter and /// provides the conversion patterns with access to the LLVMTypeConverter and

View File

@ -9,67 +9,17 @@
#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_
#include "llvm/IR/DataLayout.h"
#include <memory> #include <memory>
namespace mlir { namespace mlir {
class DataLayout;
class LLVMTypeConverter; class LLVMTypeConverter;
class MLIRContext; class LowerToLLVMOptions;
class ModuleOp; class ModuleOp;
template <typename T> template <typename T>
class OperationPass; class OperationPass;
class RewritePatternSet; class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet; using OwningRewritePatternList = RewritePatternSet;
/// Value to pass as bitwidth for the index type when the converter is expected
/// to derive the bitwidth from the LLVM data layout.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
/// Options to control the Standard dialect to LLVM lowering. The struct is used
/// to share lowering options between passes, patterns, and type converter.
class LowerToLLVMOptions {
public:
explicit LowerToLLVMOptions(MLIRContext *ctx);
explicit LowerToLLVMOptions(MLIRContext *ctx, const DataLayout &dl);
bool useBarePtrCallConv = false;
bool emitCWrappers = false;
enum class AllocLowering {
/// Use malloc for for heap allocations.
Malloc,
/// Use aligned_alloc for heap allocations.
AlignedAlloc,
/// Do not lower heap allocations. Users must provide their own patterns for
/// AllocOp and DeallocOp lowering.
None
};
AllocLowering allocLowering = AllocLowering::Malloc;
/// The data layout of the module to produce. This must be consistent with the
/// data layout used in the upper levels of the lowering pipeline.
// TODO: this should be replaced by MLIR data layout when one exists.
llvm::DataLayout dataLayout = llvm::DataLayout("");
/// Set the index bitwidth to the given value.
void overrideIndexBitwidth(unsigned bitwidth) {
assert(bitwidth != kDeriveIndexBitwidthFromDataLayout &&
"can only override to a concrete bitwidth");
indexBitwidth = bitwidth;
}
/// Get the index bitwidth.
unsigned getIndexBitwidth() const { return indexBitwidth; }
private:
unsigned indexBitwidth;
};
/// Collect a set of patterns to convert memory-related operations from the /// Collect a set of patterns to convert memory-related operations from the
/// Standard dialect to the LLVM dialect, excluding non-memory-related /// Standard dialect to the LLVM dialect, excluding non-memory-related
/// operations and FuncOp. /// operations and FuncOp.

View File

@ -11,6 +11,7 @@ add_subdirectory(GPUToVulkan)
add_subdirectory(LinalgToLLVM) add_subdirectory(LinalgToLLVM)
add_subdirectory(LinalgToSPIRV) add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard) add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
add_subdirectory(MathToLibm) add_subdirectory(MathToLibm)
add_subdirectory(OpenACCToLLVM) add_subdirectory(OpenACCToLLVM)
add_subdirectory(OpenACCToSCF) add_subdirectory(OpenACCToSCF)

View File

@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRComplexToLLVM
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRComplex MLIRComplex
MLIRLLVMCommonConversion
MLIRLLVMIR MLIRLLVMIR
MLIRStandardOpsTransforms MLIRStandardOpsTransforms
MLIRStandardToLLVM MLIRStandardToLLVM

View File

@ -9,12 +9,48 @@
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
using namespace mlir; using namespace mlir;
using namespace mlir::LLVM; using namespace mlir::LLVM;
//===----------------------------------------------------------------------===//
// ComplexStructBuilder implementation.
//===----------------------------------------------------------------------===//
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
Location loc, Type type) {
Value val = builder.create<LLVM::UndefOp>(loc, type);
return ComplexStructBuilder(val);
}
void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
Value real) {
setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
}
Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
}
void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
Value imaginary) {
setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
}
Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
}
//===----------------------------------------------------------------------===//
// Conversion patterns.
//===----------------------------------------------------------------------===//
namespace { namespace {
struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {

View File

@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToNVVMTransforms
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRGPUOps MLIRGPUOps
MLIRGPUToGPURuntimeTransforms MLIRGPUToGPURuntimeTransforms
MLIRLLVMCommonConversion
MLIRLLVMIR MLIRLLVMIR
MLIRMemRef MLIRMemRef
MLIRNVVMIR MLIRNVVMIR

View File

@ -13,7 +13,9 @@
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h"

View File

@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRGPUOps MLIRGPUOps
MLIRGPUToGPURuntimeTransforms MLIRGPUToGPURuntimeTransforms
MLIRLLVMCommonConversion
MLIRLLVMIR MLIRLLVMIR
MLIRROCDLIR MLIRROCDLIR
MLIRPass MLIRPass

View File

@ -13,7 +13,9 @@
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h" #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"

View File

@ -0,0 +1,15 @@
add_mlir_conversion_library(MLIRLLVMCommonConversion
LoweringOptions.cpp
MemRefBuilder.cpp
StructBuilder.cpp
TypeConverter.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMIR
MLIRSupport
MLIRTransforms
)

View File

@ -0,0 +1,21 @@
//===- LoweringOptions.cpp - Common config for lowering to LLVM ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
using namespace mlir;
mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx)
: LowerToLLVMOptions(ctx, DataLayout()) {}
mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx,
const DataLayout &dl) {
indexBitwidth = dl.getTypeSizeInBits(IndexType::get(ctx));
}

View File

@ -0,0 +1,525 @@
//===- MemRefBuilder.cpp - Helper for LLVM MemRef equivalents -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "MemRefDescriptor.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/MathExtras.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// MemRefDescriptor implementation
//===----------------------------------------------------------------------===//
/// Construct a helper for the given descriptor value.
MemRefDescriptor::MemRefDescriptor(Value descriptor)
: StructBuilder(descriptor) {
assert(value != nullptr && "value cannot be null");
indexType = value.getType()
.cast<LLVM::LLVMStructType>()
.getBody()[kOffsetPosInMemRefDescriptor];
}
/// Builds IR creating an `undef` value of the descriptor type.
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
Type descriptorType) {
Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
return MemRefDescriptor(descriptor);
}
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
MemRefDescriptor
MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, Value memory) {
assert(type.hasStaticShape() && "unexpected dynamic shape");
// Extract all strides and offsets and verify they are static.
int64_t offset;
SmallVector<int64_t, 4> strides;
auto result = getStridesAndOffset(type, strides, offset);
(void)result;
assert(succeeded(result) && "unexpected failure in stride computation");
assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
"expected static offset");
assert(!llvm::any_of(strides, [](int64_t stride) {
return MemRefType::isDynamicStrideOrOffset(stride);
}) && "expected static strides");
auto convertedType = typeConverter.convertType(type);
assert(convertedType && "unexpected failure in memref type conversion");
auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
descr.setAllocatedPtr(builder, loc, memory);
descr.setAlignedPtr(builder, loc, memory);
descr.setConstantOffset(builder, loc, offset);
// Fill in sizes and strides
for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
descr.setConstantSize(builder, loc, i, type.getDimSize(i));
descr.setConstantStride(builder, loc, i, strides[i]);
}
return descr;
}
/// Builds IR extracting the allocated pointer from the descriptor.
Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the allocated pointer into the descriptor.
void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
Value ptr) {
setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
}
/// Builds IR extracting the aligned pointer from the descriptor.
Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the aligned pointer into the descriptor.
void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
Value ptr) {
setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
}
// Creates a constant Op producing a value of `resultType` from an index-typed
// integer attribute.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
}
/// Builds IR extracting the offset from the descriptor.
Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
}
/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value offset) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, offset,
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
}
/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
uint64_t offset) {
setOffset(builder, loc,
createIndexAttrConstant(builder, loc, indexType, offset));
}
/// Builds IR extracting the pos-th size from the descriptor.
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
int64_t rank) {
auto indexPtrTy = LLVM::LLVMPointerType::get(indexType);
auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank);
auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
// Copy size values to stack-allocated memory.
auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
auto one = createIndexAttrConstant(builder, loc, indexType, 1);
auto sizes = builder.create<LLVM::ExtractValueOp>(
loc, arrayTy, value,
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
auto sizesPtr =
builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
// Load an return size value of interest.
auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
ValueRange({zero, pos}));
return builder.create<LLVM::LoadOp>(loc, resultPtr);
}
/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
Value size) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, size,
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}
void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
unsigned pos, uint64_t size) {
setSize(builder, loc, pos,
createIndexAttrConstant(builder, loc, indexType, size));
}
/// Builds IR extracting the pos-th stride from the descriptor.
Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
Value stride) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, stride,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}
void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
unsigned pos, uint64_t stride) {
setStride(builder, loc, pos,
createIndexAttrConstant(builder, loc, indexType, stride));
}
LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
return value.getType()
.cast<LLVM::LLVMStructType>()
.getBody()[kAlignedPtrPosInMemRefDescriptor]
.cast<LLVM::LLVMPointerType>();
}
/// Creates a MemRef descriptor structure from a list of individual values
/// composing that descriptor, in the following order:
/// - allocated pointer;
/// - aligned pointer;
/// - offset;
/// - <rank> sizes;
/// - <rank> shapes;
/// where <rank> is the MemRef rank as provided in `type`.
Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
LLVMTypeConverter &converter, MemRefType type,
ValueRange values) {
Type llvmType = converter.convertType(type);
auto d = MemRefDescriptor::undef(builder, loc, llvmType);
d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
int64_t rank = type.getRank();
for (unsigned i = 0; i < rank; ++i) {
d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
}
return d;
}
/// Builds IR extracting individual elements of a MemRef descriptor structure
/// and returning them as `results` list.
void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
MemRefType type,
SmallVectorImpl<Value> &results) {
int64_t rank = type.getRank();
results.reserve(results.size() + getNumUnpackedValues(type));
MemRefDescriptor d(packed);
results.push_back(d.allocatedPtr(builder, loc));
results.push_back(d.alignedPtr(builder, loc));
results.push_back(d.offset(builder, loc));
for (int64_t i = 0; i < rank; ++i)
results.push_back(d.size(builder, loc, i));
for (int64_t i = 0; i < rank; ++i)
results.push_back(d.stride(builder, loc, i));
}
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
// Two pointers, offset, <rank> sizes, <rank> shapes.
return 3 + 2 * type.getRank();
}
//===----------------------------------------------------------------------===//
// MemRefDescriptorView implementation.
//===----------------------------------------------------------------------===//
MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
: rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
Value MemRefDescriptorView::allocatedPtr() {
return elements[kAllocatedPtrPosInMemRefDescriptor];
}
Value MemRefDescriptorView::alignedPtr() {
return elements[kAlignedPtrPosInMemRefDescriptor];
}
Value MemRefDescriptorView::offset() {
return elements[kOffsetPosInMemRefDescriptor];
}
Value MemRefDescriptorView::size(unsigned pos) {
return elements[kSizePosInMemRefDescriptor + pos];
}
Value MemRefDescriptorView::stride(unsigned pos) {
return elements[kSizePosInMemRefDescriptor + rank + pos];
}
//===----------------------------------------------------------------------===//
// UnrankedMemRefDescriptor implementation
//===----------------------------------------------------------------------===//
/// Construct a helper for the given descriptor value.
UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
: StructBuilder(descriptor) {}
/// Builds IR creating an `undef` value of the descriptor type.
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
Location loc,
Type descriptorType) {
Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
return UnrankedMemRefDescriptor(descriptor);
}
Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
}
void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
Value v) {
setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
}
Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
Location loc) {
return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
}
void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
Location loc, Value v) {
setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
}
/// Builds IR populating an unranked MemRef descriptor structure from a list
/// of individual constituent values in the following order:
/// - rank of the memref;
/// - pointer to the memref descriptor.
Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
LLVMTypeConverter &converter,
UnrankedMemRefType type,
ValueRange values) {
Type llvmType = converter.convertType(type);
auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
return d;
}
/// Builds IR extracting individual elements that compose an unranked memref
/// descriptor and returns them as `results` list.
void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
Value packed,
SmallVectorImpl<Value> &results) {
UnrankedMemRefDescriptor d(packed);
results.reserve(results.size() + 2);
results.push_back(d.rank(builder, loc));
results.push_back(d.memRefDescPtr(builder, loc));
}
void UnrankedMemRefDescriptor::computeSizes(
OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
if (values.empty())
return;
// Cache the index type.
Type indexType = typeConverter.getIndexType();
// Initialize shared constants.
Value one = createIndexAttrConstant(builder, loc, indexType, 1);
Value two = createIndexAttrConstant(builder, loc, indexType, 2);
Value pointerSize = createIndexAttrConstant(
builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
Value indexSize =
createIndexAttrConstant(builder, loc, indexType,
ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
sizes.reserve(sizes.size() + values.size());
for (UnrankedMemRefDescriptor desc : values) {
// Emit IR computing the memory necessary to store the descriptor. This
// assumes the descriptor to be
// { type*, type*, index, index[rank], index[rank] }
// and densely packed, so the total size is
// 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
// TODO: consider including the actual size (including eventual padding due
// to data layout) into the unranked descriptor.
Value doublePointerSize =
builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
// (1 + 2 * rank) * sizeof(index)
Value rank = desc.rank(builder, loc);
Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
Value doubleRankIncremented =
builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
Value rankIndexSize = builder.create<LLVM::MulOp>(
loc, indexType, doubleRankIncremented, indexSize);
// Total allocation size.
Value allocationSize = builder.create<LLVM::AddOp>(
loc, indexType, doublePointerSize, rankIndexSize);
sizes.push_back(allocationSize);
}
}
Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
Value memRefDescPtr,
Type elemPtrPtrType) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
return builder.create<LLVM::LoadOp>(loc, elementPtrPtr);
}
void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
Value memRefDescPtr,
Type elemPtrPtrType,
Value allocatedPtr) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr);
}
Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
Type elemPtrPtrType) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
Value one =
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
Value alignedGep = builder.create<LLVM::GEPOp>(
loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
return builder.create<LLVM::LoadOp>(loc, alignedGep);
}
void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
Type elemPtrPtrType,
Value alignedPtr) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
Value one =
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
Value alignedGep = builder.create<LLVM::GEPOp>(
loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
}
Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
Type elemPtrPtrType) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
Value two =
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
Value offsetGep = builder.create<LLVM::GEPOp>(
loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
offsetGep = builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
return builder.create<LLVM::LoadOp>(loc, offsetGep);
}
void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value memRefDescPtr,
Type elemPtrPtrType, Value offset) {
Value elementPtrPtr =
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
Value two =
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
Value offsetGep = builder.create<LLVM::GEPOp>(
loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
offsetGep = builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
}
Value UnrankedMemRefDescriptor::sizeBasePtr(
OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
Type elemPtrTy = elemPtrPtrType.getElementType();
Type indexTy = typeConverter.getIndexType();
Type structPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral(
indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy}));
Value structPtr =
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
Type int32_type = typeConverter.convertType(builder.getI32Type());
Value zero =
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
builder.getI32IntegerAttr(3));
return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
structPtr, ValueRange({zero, three}));
}
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value sizeBasePtr, Value index) {
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({index}));
return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
}
void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value sizeBasePtr, Value index,
Value size) {
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({index}));
builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
}
Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
Value sizeBasePtr, Value rank) {
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
ValueRange({rank}));
}
Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value strideBasePtr, Value index,
Value stride) {
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value strideStoreGep = builder.create<LLVM::GEPOp>(
loc, indexPtrTy, strideBasePtr, ValueRange({index}));
return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
}
void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
LLVMTypeConverter typeConverter,
Value strideBasePtr, Value index,
Value stride) {
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
Value strideStoreGep = builder.create<LLVM::GEPOp>(
loc, indexPtrTy, strideBasePtr, ValueRange({index}));
builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
}

View File

@ -0,0 +1,25 @@
//===- MemRefDescriptor.h - MemRef descriptor constants ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines constants that are used in LLVM dialect equivalents of MemRef type.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_LIB_CONVERSION_LLVMCOMMON_MEMREFDESCRIPTOR_H
#define MLIR_LIB_CONVERSION_LLVMCOMMON_MEMREFDESCRIPTOR_H
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
static constexpr unsigned kSizePosInMemRefDescriptor = 3;
static constexpr unsigned kStridePosInMemRefDescriptor = 4;
static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
#endif // MLIR_LIB_CONVERSION_LLVMCOMMON_MEMREFDESCRIPTOR_H

View File

@ -0,0 +1,36 @@
//===- StructBuilder.cpp - Helper for building LLVM structs --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// StructBuilder implementation
//===----------------------------------------------------------------------===//
StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) {
assert(value != nullptr && "value cannot be null");
assert(LLVM::isCompatibleType(structType) && "expected llvm type");
}
Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
unsigned pos) {
Type type = structType.cast<LLVM::LLVMStructType>().getBody()[pos];
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
builder.getI64ArrayAttr(pos));
}
void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
Value ptr) {
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
builder.getI64ArrayAttr(pos));
}

View File

@ -0,0 +1,492 @@
//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "MemRefDescriptor.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
using namespace mlir;
/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
const DataLayoutAnalysis *analysis)
: LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
const LowerToLLVMOptions &options,
const DataLayoutAnalysis *analysis)
: llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
dataLayoutAnalysis(analysis) {
assert(llvmDialect && "LLVM IR dialect is not registered");
// Register conversions for the builtin types.
addConversion([&](ComplexType type) { return convertComplexType(type); });
addConversion([&](FloatType type) { return convertFloatType(type); });
addConversion([&](FunctionType type) { return convertFunctionType(type); });
addConversion([&](IndexType type) { return convertIndexType(type); });
addConversion([&](IntegerType type) { return convertIntegerType(type); });
addConversion([&](MemRefType type) { return convertMemRefType(type); });
addConversion(
[&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
addConversion([&](VectorType type) { return convertVectorType(type); });
// LLVM-compatible types are legal, so add a pass-through conversion.
addConversion([](Type type) {
return LLVM::isCompatibleType(type) ? llvm::Optional<Type>(type)
: llvm::None;
});
// Materialization for memrefs creates descriptor structs from individual
// values constituting them, when descriptors are used, i.e. more than one
// value represents a memref.
addArgumentMaterialization(
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
Location loc) -> Optional<Value> {
if (inputs.size() == 1)
return llvm::None;
return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
inputs);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs,
Location loc) -> Optional<Value> {
if (inputs.size() == 1)
return llvm::None;
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
});
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> Optional<Value> {
if (inputs.size() != 1)
return llvm::None;
// FIXME: These should check LLVM::DialectCastOp can actually be constructed
// from the input and result.
return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
.getResult();
});
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> Optional<Value> {
if (inputs.size() != 1)
return llvm::None;
// FIXME: These should check LLVM::DialectCastOp can actually be constructed
// from the input and result.
return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
.getResult();
});
}
/// Returns the MLIR context.
MLIRContext &LLVMTypeConverter::getContext() {
return *getDialect()->getContext();
}
Type LLVMTypeConverter::getIndexType() {
return IntegerType::get(&getContext(), getIndexTypeBitwidth());
}
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
return options.dataLayout.getPointerSizeInBits(addressSpace);
}
Type LLVMTypeConverter::convertIndexType(IndexType type) {
return getIndexType();
}
Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
return IntegerType::get(&getContext(), type.getWidth());
}
Type LLVMTypeConverter::convertFloatType(FloatType type) { return type; }
// Convert a `ComplexType` to an LLVM type. The result is a complex number
// struct with entries for the
// 1. real part and for the
// 2. imaginary part.
Type LLVMTypeConverter::convertComplexType(ComplexType type) {
auto elementType = convertType(type.getElementType());
return LLVM::LLVMStructType::getLiteral(&getContext(),
{elementType, elementType});
}
// Except for signatures, MLIR function types are converted into LLVM
// pointer-to-function types.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
SignatureConversion conversion(type.getNumInputs());
Type converted =
convertFunctionSignature(type, /*isVariadic=*/false, conversion);
return LLVM::LLVMPointerType::get(converted);
}
// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
// they are into an LLVM StructType in their order of appearance.
Type LLVMTypeConverter::convertFunctionSignature(
FunctionType funcTy, bool isVariadic,
LLVMTypeConverter::SignatureConversion &result) {
// Select the argument converter depending on the calling convention.
auto funcArgConverter = options.useBarePtrCallConv
? barePtrFuncArgTypeConverter
: structFuncArgTypeConverter;
// Convert argument types one by one and check for errors.
for (auto &en : llvm::enumerate(funcTy.getInputs())) {
Type type = en.value();
SmallVector<Type, 8> converted;
if (failed(funcArgConverter(*this, type, converted)))
return {};
result.addInputs(en.index(), converted);
}
SmallVector<Type, 8> argTypes;
argTypes.reserve(llvm::size(result.getConvertedTypes()));
for (Type type : result.getConvertedTypes())
argTypes.push_back(type);
// If function does not return anything, create the void result type,
// if it returns on element, convert it, otherwise pack the result types into
// a struct.
Type resultType = funcTy.getNumResults() == 0
? LLVM::LLVMVoidType::get(&getContext())
: packFunctionResults(funcTy.getResults());
if (!resultType)
return {};
return LLVM::LLVMFunctionType::get(resultType, argTypes, isVariadic);
}
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
std::pair<Type, bool>
LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
SmallVector<Type, 4> inputs;
bool resultIsNowArg = false;
Type resultType = type.getNumResults() == 0
? LLVM::LLVMVoidType::get(&getContext())
: packFunctionResults(type.getResults());
if (!resultType)
return {};
if (auto structType = resultType.dyn_cast<LLVM::LLVMStructType>()) {
// Struct types cannot be safely returned via C interface. Make this a
// pointer argument, instead.
inputs.push_back(LLVM::LLVMPointerType::get(structType));
resultType = LLVM::LLVMVoidType::get(&getContext());
resultIsNowArg = true;
}
for (Type t : type.getInputs()) {
auto converted = convertType(t);
if (!converted || !LLVM::isCompatibleType(converted))
return {};
if (t.isa<MemRefType, UnrankedMemRefType>())
converted = LLVM::LLVMPointerType::get(converted);
inputs.push_back(converted);
}
return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg};
}
/// Convert a memref type into a list of LLVM IR types that will form the
/// memref descriptor. The result contains the following types:
/// 1. The pointer to the allocated data buffer, followed by
/// 2. The pointer to the aligned data buffer, followed by
/// 3. A lowered `index`-type integer containing the distance between the
/// beginning of the buffer and the first element to be accessed through the
/// view, followed by
/// 4. An array containing as many `index`-type integers as the rank of the
/// MemRef: the array represents the size, in number of elements, of the memref
/// along the given dimension. For constant MemRef dimensions, the
/// corresponding size entry is a constant whose runtime value must match the
/// static value, followed by
/// 5. A second array containing as many `index`-type integers as the rank of
/// the MemRef: the second array represents the "stride" (in tensor abstraction
/// sense), i.e. the number of consecutive elements of the underlying buffer.
/// TODO: add assertions for the static cases.
///
/// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
/// are expanded into individual index-type elements.
///
/// template <typename Elem, typename Index, size_t Rank>
/// struct {
/// Elem *allocatedPtr;
/// Elem *alignedPtr;
/// Index offset;
/// Index sizes[Rank]; // omitted when rank == 0
/// Index strides[Rank]; // omitted when rank == 0
/// };
SmallVector<Type, 5>
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
bool unpackAggregates) {
assert(isStrided(type) &&
"Non-strided layout maps must have been normalized away");
Type elementType = convertType(type.getElementType());
if (!elementType)
return {};
auto ptrTy =
LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
auto indexTy = getIndexType();
SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
auto rank = type.getRank();
if (rank == 0)
return results;
if (unpackAggregates)
results.insert(results.end(), 2 * rank, indexTy);
else
results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
return results;
}
unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
const DataLayout &layout) {
// Compute the descriptor size given that of its components indicated above.
unsigned space = type.getMemorySpaceAsInt();
return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
(1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
}
/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
// When converting a MemRefType to a struct with descriptor fields, do not
// unpack the `sizes` and `strides` arrays.
SmallVector<Type, 5> types =
getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
if (types.empty())
return {};
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
}
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
/// that will form the unranked memref descriptor. In particular, the fields
/// for an unranked memref descriptor are:
/// 1. index-typed rank, the dynamic rank of this MemRef
/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
/// stack allocated (alloca) copy of a MemRef descriptor that got casted to
/// be unranked.
SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
return {getIndexType(),
LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8))};
}
unsigned
LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
const DataLayout &layout) {
// Compute the descriptor size given that of its components indicated above.
unsigned space = type.getMemorySpaceAsInt();
return layout.getTypeSize(getIndexType()) +
llvm::divideCeil(getPointerBitwidth(space), 8);
}
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
if (!convertType(type.getElementType()))
return {};
return LLVM::LLVMStructType::getLiteral(&getContext(),
getUnrankedMemRefDescriptorFields());
}
/// Convert a memref type to a bare pointer to the memref element type.
Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
if (type.isa<UnrankedMemRefType>())
// Unranked memref is not supported in the bare pointer calling convention.
return {};
// Check that the memref has static shape, strides and offset. Otherwise, it
// cannot be lowered to a bare pointer.
auto memrefTy = type.cast<MemRefType>();
if (!memrefTy.hasStaticShape())
return {};
int64_t offset = 0;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(memrefTy, strides, offset)))
return {};
for (int64_t stride : strides)
if (ShapedType::isDynamicStrideOrOffset(stride))
return {};
if (ShapedType::isDynamicStrideOrOffset(offset))
return {};
Type elementType = convertType(type.getElementType());
if (!elementType)
return {};
return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
}
/// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type
/// when n > 1. For example, `vector<4 x f32>` remains as is while,
/// `vector<4x8x16xf32>` converts to `!llvm.array<4xarray<8 x vector<16xf32>>>`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
auto elementType = convertType(type.getElementType());
if (!elementType)
return {};
Type vectorType = VectorType::get(type.getShape().back(), elementType);
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
return vectorType;
}
/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
Type LLVMTypeConverter::convertCallingConventionType(Type type) {
if (options.useBarePtrCallConv)
if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
return convertMemRefToBarePtr(memrefTy);
return convertType(type);
}
/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
void LLVMTypeConverter::promoteBarePtrsToDescriptors(
ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
SmallVectorImpl<Value> &values) {
assert(stdTypes.size() == values.size() &&
"The number of types and values doesn't match");
for (unsigned i = 0, end = values.size(); i < end; ++i)
if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
memrefTy, values[i]);
}
/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one value is returned,
/// create an LLVM IR structure type with elements that correspond to each of
/// the MLIR types converted with `convertType`.
Type LLVMTypeConverter::packFunctionResults(TypeRange types) {
assert(!types.empty() && "expected non-empty list of type");
if (types.size() == 1)
return convertCallingConventionType(types.front());
SmallVector<Type, 8> resultTypes;
resultTypes.reserve(types.size());
for (auto t : types) {
auto converted = convertCallingConventionType(t);
if (!converted || !LLVM::isCompatibleType(converted))
return {};
resultTypes.push_back(converted);
}
return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}
Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
OpBuilder &builder) {
auto *context = builder.getContext();
auto int64Ty = IntegerType::get(builder.getContext(), 64);
auto indexType = IndexType::get(context);
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
auto ptrType = LLVM::LLVMPointerType::get(operand.getType());
Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
IntegerAttr::get(indexType, 1));
Value allocated =
builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
// Store into the alloca'ed descriptor.
builder.create<LLVM::StoreOp>(loc, operand, allocated);
return allocated;
}
SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
ValueRange opOperands,
ValueRange operands,
OpBuilder &builder) {
SmallVector<Value, 4> promotedOperands;
promotedOperands.reserve(operands.size());
for (auto it : llvm::zip(opOperands, operands)) {
auto operand = std::get<0>(it);
auto llvmOperand = std::get<1>(it);
if (options.useBarePtrCallConv) {
// For the bare-ptr calling convention, we only have to extract the
// aligned pointer of a memref.
if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
MemRefDescriptor desc(llvmOperand);
llvmOperand = desc.alignedPtr(builder, loc);
} else if (operand.getType().isa<UnrankedMemRefType>()) {
llvm_unreachable("Unranked memrefs are not supported");
}
} else {
if (operand.getType().isa<UnrankedMemRefType>()) {
UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
promotedOperands);
continue;
}
if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
promotedOperands);
continue;
}
}
promotedOperands.push_back(llvmOperand);
}
return promotedOperands;
}
/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result) {
if (auto memref = type.dyn_cast<MemRefType>()) {
// In signatures, Memref descriptors are expanded into lists of
// non-aggregate values.
auto converted =
converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
if (converted.empty())
return failure();
result.append(converted.begin(), converted.end());
return success();
}
if (type.isa<UnrankedMemRefType>()) {
auto converted = converter.getUnrankedMemRefDescriptorFields();
if (converted.empty())
return failure();
result.append(converted.begin(), converted.end());
return success();
}
auto converted = converter.convertType(type);
if (!converted)
return failure();
result.push_back(converted);
return success();
}
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result) {
auto llvmTy = converter.convertCallingConventionType(type);
if (!llvmTy)
return failure();
result.push_back(llvmTy);
return success();
}

View File

@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRSPIRVToLLVM
MLIRGPUOps MLIRGPUOps
MLIRSPIRV MLIRSPIRV
MLIRSPIRVUtils MLIRSPIRVUtils
MLIRLLVMCommonConversion
MLIRLLVMIR MLIRLLVMIR
MLIRStandardToLLVM MLIRStandardToLLVM
MLIRIR MLIRIR

View File

@ -12,6 +12,8 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"

View File

@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRStandardToLLVM
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRAnalysis MLIRAnalysis
MLIRDataLayoutInterfaces MLIRDataLayoutInterfaces
MLIRLLVMCommonConversion
MLIRLLVMIR MLIRLLVMIR
MLIRMath MLIRMath
MLIRMemRef MLIRMemRef

File diff suppressed because it is too large Load Diff

View File

@ -61,6 +61,7 @@ if (MLIR_ENABLE_VULKAN_RUNNER)
MLIRIR MLIRIR
MLIRJitRunner MLIRJitRunner
MLIRLLVMIR MLIRLLVMIR
MLIRLLVMCommonConversion
MLIRLLVMToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation
MLIRMemRef MLIRMemRef
MLIRParser MLIRParser

View File

@ -14,6 +14,7 @@
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"