forked from OSchip/llvm-project
[mlir] factor out ConvertToLLVMPattern
This class and classes that extend it are general utilities for any dialect that is being converted into the LLVM dialect. They are in no way specific to Standard-to-LLVM conversion and should not make their users depend on it. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D105542
This commit is contained in:
parent
9ced1e44ad
commit
684dfe8adb
|
@ -0,0 +1,193 @@
|
|||
//===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
|
||||
#define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace LLVM {
|
||||
namespace detail {
|
||||
/// Replaces the given operation "op" with a new operation of type "targetOp"
|
||||
/// and given operands.
|
||||
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
|
||||
ValueRange operands,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
} // namespace detail
|
||||
} // namespace LLVM
|
||||
|
||||
/// Base class for operation conversions targeting the LLVM IR dialect. It
|
||||
/// provides the conversion patterns with access to the LLVMTypeConverter and
|
||||
/// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
|
||||
/// LowerToLLVMOptions by reference meaning the references have to remain alive
|
||||
/// during the entire pattern lifetime.
|
||||
class ConvertToLLVMPattern : public ConversionPattern {
|
||||
public:
|
||||
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1);
|
||||
|
||||
protected:
|
||||
/// Returns the LLVM dialect.
|
||||
LLVM::LLVMDialect &getDialect() const;
|
||||
|
||||
LLVMTypeConverter *getTypeConverter() const;
|
||||
|
||||
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
|
||||
/// defined by the used type converter.
|
||||
Type getIndexType() const;
|
||||
|
||||
/// Gets the MLIR type wrapping the LLVM integer type whose bit width
|
||||
/// corresponds to that of a LLVM pointer type.
|
||||
Type getIntPtrType(unsigned addressSpace = 0) const;
|
||||
|
||||
/// Gets the MLIR type wrapping the LLVM void type.
|
||||
Type getVoidType() const;
|
||||
|
||||
/// Get the MLIR type wrapping the LLVM i8* type.
|
||||
Type getVoidPtrType() const;
|
||||
|
||||
/// Create a constant Op producing a value of `resultType` from an index-typed
|
||||
/// integer attribute.
|
||||
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
|
||||
Type resultType, int64_t value);
|
||||
|
||||
/// Create an LLVM dialect operation defining the given index constant.
|
||||
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
|
||||
uint64_t value) const;
|
||||
|
||||
// This is a strided getElementPtr variant that linearizes subscripts as:
|
||||
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
|
||||
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
|
||||
ValueRange indices,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
/// Returns if the given memref has identity maps and the element type is
|
||||
/// convertible to LLVM.
|
||||
bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
|
||||
|
||||
/// Returns the type of a pointer to an element of the memref.
|
||||
Type getElementPtrType(MemRefType type) const;
|
||||
|
||||
/// Computes sizes, strides and buffer size in bytes of `memRefType` with
|
||||
/// identity layout. Emits constant ops for the static sizes of `memRefType`,
|
||||
/// and uses `dynamicSizes` for the others. Emits instructions to compute
|
||||
/// strides and buffer size from these sizes.
|
||||
///
|
||||
/// For example, memref<4x?xf32> emits:
|
||||
/// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
|
||||
/// `sizes[1]` = `dynamicSizes[0]`
|
||||
/// `strides[1]` = llvm.mlir.constant(1 : index) : i64
|
||||
/// `strides[0]` = `sizes[0]`
|
||||
/// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
|
||||
/// %nullptr = llvm.mlir.null : !llvm.ptr<f32>
|
||||
/// %gep = llvm.getelementptr %nullptr[%size]
|
||||
/// : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
|
||||
/// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64
|
||||
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
|
||||
ValueRange dynamicSizes,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
SmallVectorImpl<Value> &sizes,
|
||||
SmallVectorImpl<Value> &strides,
|
||||
Value &sizeBytes) const;
|
||||
|
||||
/// Computes the size of type in bytes.
|
||||
Value getSizeInBytes(Location loc, Type type,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
/// Computes total number of elements for the given shape.
|
||||
Value getNumElements(Location loc, ArrayRef<Value> shape,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
/// Creates and populates a canonical memref descriptor struct.
|
||||
MemRefDescriptor
|
||||
createMemRefDescriptor(Location loc, MemRefType memRefType,
|
||||
Value allocatedPtr, Value alignedPtr,
|
||||
ArrayRef<Value> sizes, ArrayRef<Value> strides,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
};
|
||||
|
||||
/// Utility class for operation conversions targeting the LLVM dialect that
|
||||
/// match exactly one source operation.
|
||||
template <typename SourceOp>
|
||||
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertToLLVMPattern(SourceOp::getOperationName(),
|
||||
&typeConverter.getContext(), typeConverter,
|
||||
benefit) {}
|
||||
|
||||
/// Wrappers around the RewritePattern methods that pass the derived op type.
|
||||
void rewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
rewrite(cast<SourceOp>(op), operands, rewriter);
|
||||
}
|
||||
LogicalResult match(Operation *op) const final {
|
||||
return match(cast<SourceOp>(op));
|
||||
}
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
|
||||
}
|
||||
|
||||
/// Rewrite and Match methods that operate on the SourceOp type. These must be
|
||||
/// overridden by the derived pattern class.
|
||||
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
llvm_unreachable("must override rewrite or matchAndRewrite");
|
||||
}
|
||||
virtual LogicalResult match(SourceOp op) const {
|
||||
llvm_unreachable("must override match or matchAndRewrite");
|
||||
}
|
||||
virtual LogicalResult
|
||||
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (succeeded(match(op))) {
|
||||
rewrite(op, operands, rewriter);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
using ConvertToLLVMPattern::match;
|
||||
using ConvertToLLVMPattern::matchAndRewrite;
|
||||
};
|
||||
|
||||
/// Generic implementation of one-to-one conversion from "SourceOp" to
|
||||
/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
|
||||
/// Upholds a convention that multi-result operations get converted into an
|
||||
/// operation returning the LLVM IR structure type, in which case individual
|
||||
/// values must be extracted from using LLVM::ExtractValueOp before being used.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
|
||||
using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;
|
||||
|
||||
/// Converts the type of the result to an LLVM type, pass operands as is,
|
||||
/// preserve attributes.
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
|
||||
operands, *this->getTypeConverter(),
|
||||
rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
|
|
@ -0,0 +1,85 @@
|
|||
//===- VectorPattern.h - Conversion pattern to the LLVM dialect -*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
|
||||
#define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace LLVM {
|
||||
namespace detail {
|
||||
// Helper struct to "unroll" operations on n-D vectors in terms of operations on
|
||||
// 1-D LLVM vectors.
|
||||
struct NDVectorTypeInfo {
|
||||
// LLVM array struct which encodes n-D vectors.
|
||||
Type llvmNDVectorTy;
|
||||
// LLVM vector type which encodes the inner 1-D vector type.
|
||||
Type llvm1DVectorTy;
|
||||
// Multiplicity of llvmNDVectorTy to llvm1DVectorTy.
|
||||
SmallVector<int64_t, 4> arraySizes;
|
||||
};
|
||||
|
||||
// For >1-D vector types, extracts the necessary information to iterate over all
|
||||
// 1-D subvectors in the underlying llrepresentation of the n-D vector
|
||||
// Iterates on the llvm array type until we hit a non-array type (which is
|
||||
// asserted to be an llvm vector type).
|
||||
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
|
||||
LLVMTypeConverter &converter);
|
||||
|
||||
// Express `linearIndex` in terms of coordinates of `basis`.
|
||||
// Returns the empty vector when linearIndex is out of the range [0, P] where
|
||||
// P is the product of all the basis coordinates.
|
||||
//
|
||||
// Prerequisites:
|
||||
// Basis is an array of nonnegative integers (signed type inherited from
|
||||
// vector shape type).
|
||||
SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
|
||||
unsigned linearIndex);
|
||||
|
||||
// Iterate of linear index, convert to coords space and insert splatted 1-D
|
||||
// vector in each position.
|
||||
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
|
||||
function_ref<void(ArrayAttr)> fun);
|
||||
|
||||
LogicalResult handleMultidimensionalVectors(
|
||||
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
|
||||
std::function<Value(Type, ValueRange)> createOperand,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
|
||||
ValueRange operands,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
} // namespace detail
|
||||
} // namespace LLVM
|
||||
|
||||
/// Basic lowering implementation to rewrite Ops with just one result to the
|
||||
/// LLVM Dialect. This supports higher-dimensional vector types.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
|
||||
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
static_assert(
|
||||
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
|
||||
"expected single result op");
|
||||
return LLVM::detail::vectorOneToOneRewrite(
|
||||
op, TargetOp::getOperationName(), operands, *this->getTypeConverter(),
|
||||
rewriter);
|
||||
}
|
||||
};
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
|
|
@ -15,167 +15,37 @@
|
|||
#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
|
||||
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace llvm {
|
||||
class IntegerType;
|
||||
class LLVMContext;
|
||||
class Module;
|
||||
class Type;
|
||||
} // namespace llvm
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class BaseMemRefType;
|
||||
class ComplexType;
|
||||
class DataLayoutAnalysis;
|
||||
class LLVMTypeConverter;
|
||||
class UnrankedMemRefType;
|
||||
class RewritePatternSet;
|
||||
|
||||
namespace LLVM {
|
||||
class LLVMDialect;
|
||||
class LLVMPointerType;
|
||||
} // namespace LLVM
|
||||
/// Collect a set of patterns to convert memory-related operations from the
|
||||
/// Standard dialect to the LLVM dialect, excluding non-memory-related
|
||||
/// operations and FuncOp.
|
||||
void populateStdToLLVMMemoryConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
// ------------------
|
||||
/// Collect a set of patterns to convert from the Standard dialect to the LLVM
|
||||
/// dialect, excluding the memory-related operations.
|
||||
void populateStdToLLVMNonMemoryConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Base class for operation conversions targeting the LLVM IR dialect. It
|
||||
/// provides the conversion patterns with access to the LLVMTypeConverter and
|
||||
/// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
|
||||
/// LowerToLLVMOptions by reference meaning the references have to remain alive
|
||||
/// during the entire pattern lifetime.
|
||||
class ConvertToLLVMPattern : public ConversionPattern {
|
||||
public:
|
||||
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1);
|
||||
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
|
||||
/// `emitCWrappers` is set, the pattern will also produce functions
|
||||
/// that pass memref descriptors by pointer-to-structure in addition to the
|
||||
/// default unpacked form.
|
||||
void populateStdToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
protected:
|
||||
/// Returns the LLVM dialect.
|
||||
LLVM::LLVMDialect &getDialect() const;
|
||||
|
||||
LLVMTypeConverter *getTypeConverter() const;
|
||||
|
||||
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
|
||||
/// defined by the used type converter.
|
||||
Type getIndexType() const;
|
||||
|
||||
/// Gets the MLIR type wrapping the LLVM integer type whose bit width
|
||||
/// corresponds to that of a LLVM pointer type.
|
||||
Type getIntPtrType(unsigned addressSpace = 0) const;
|
||||
|
||||
/// Gets the MLIR type wrapping the LLVM void type.
|
||||
Type getVoidType() const;
|
||||
|
||||
/// Get the MLIR type wrapping the LLVM i8* type.
|
||||
Type getVoidPtrType() const;
|
||||
|
||||
/// Create an LLVM dialect operation defining the given index constant.
|
||||
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
|
||||
uint64_t value) const;
|
||||
|
||||
// This is a strided getElementPtr variant that linearizes subscripts as:
|
||||
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
|
||||
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
|
||||
ValueRange indices,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
/// Returns if the given memref has identity maps and the element type is
|
||||
/// convertible to LLVM.
|
||||
bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
|
||||
|
||||
/// Returns the type of a pointer to an element of the memref.
|
||||
Type getElementPtrType(MemRefType type) const;
|
||||
|
||||
/// Computes sizes, strides and buffer size in bytes of `memRefType` with
|
||||
/// identity layout. Emits constant ops for the static sizes of `memRefType`,
|
||||
/// and uses `dynamicSizes` for the others. Emits instructions to compute
|
||||
/// strides and buffer size from these sizes.
|
||||
///
|
||||
/// For example, memref<4x?xf32> emits:
|
||||
/// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
|
||||
/// `sizes[1]` = `dynamicSizes[0]`
|
||||
/// `strides[1]` = llvm.mlir.constant(1 : index) : i64
|
||||
/// `strides[0]` = `sizes[0]`
|
||||
/// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
|
||||
/// %nullptr = llvm.mlir.null : !llvm.ptr<f32>
|
||||
/// %gep = llvm.getelementptr %nullptr[%size]
|
||||
/// : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
|
||||
/// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64
|
||||
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
|
||||
ValueRange dynamicSizes,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
SmallVectorImpl<Value> &sizes,
|
||||
SmallVectorImpl<Value> &strides,
|
||||
Value &sizeBytes) const;
|
||||
|
||||
/// Computes the size of type in bytes.
|
||||
Value getSizeInBytes(Location loc, Type type,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
/// Computes total number of elements for the given shape.
|
||||
Value getNumElements(Location loc, ArrayRef<Value> shape,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
/// Creates and populates a canonical memref descriptor struct.
|
||||
MemRefDescriptor
|
||||
createMemRefDescriptor(Location loc, MemRefType memRefType,
|
||||
Value allocatedPtr, Value alignedPtr,
|
||||
ArrayRef<Value> sizes, ArrayRef<Value> strides,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
};
|
||||
|
||||
/// Utility class for operation conversions targeting the LLVM dialect that
|
||||
/// match exactly one source operation.
|
||||
template <typename SourceOp>
|
||||
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertToLLVMPattern(SourceOp::getOperationName(),
|
||||
&typeConverter.getContext(), typeConverter,
|
||||
benefit) {}
|
||||
|
||||
/// Wrappers around the RewritePattern methods that pass the derived op type.
|
||||
void rewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
rewrite(cast<SourceOp>(op), operands, rewriter);
|
||||
}
|
||||
LogicalResult match(Operation *op) const final {
|
||||
return match(cast<SourceOp>(op));
|
||||
}
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
|
||||
}
|
||||
|
||||
/// Rewrite and Match methods that operate on the SourceOp type. These must be
|
||||
/// overridden by the derived pattern class.
|
||||
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
llvm_unreachable("must override rewrite or matchAndRewrite");
|
||||
}
|
||||
virtual LogicalResult match(SourceOp op) const {
|
||||
llvm_unreachable("must override match or matchAndRewrite");
|
||||
}
|
||||
virtual LogicalResult
|
||||
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (succeeded(match(op))) {
|
||||
rewrite(op, operands, rewriter);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
using ConvertToLLVMPattern::match;
|
||||
using ConvertToLLVMPattern::matchAndRewrite;
|
||||
};
|
||||
/// Collect the patterns to convert from the Standard dialect to LLVM. The
|
||||
/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
|
||||
/// by reference meaning the references have to remain alive during the entire
|
||||
/// pattern lifetime.
|
||||
void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Lowering for AllocOp and AllocaOp.
|
||||
struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
|
||||
|
@ -226,64 +96,6 @@ private:
|
|||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
namespace LLVM {
|
||||
namespace detail {
|
||||
/// Replaces the given operation "op" with a new operation of type "targetOp"
|
||||
/// and given operands.
|
||||
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
|
||||
ValueRange operands,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
|
||||
ValueRange operands,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
} // namespace detail
|
||||
} // namespace LLVM
|
||||
|
||||
/// Generic implementation of one-to-one conversion from "SourceOp" to
|
||||
/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
|
||||
/// Upholds a convention that multi-result operations get converted into an
|
||||
/// operation returning the LLVM IR structure type, in which case individual
|
||||
/// values must be extracted from using LLVM::ExtractValueOp before being used.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
|
||||
using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;
|
||||
|
||||
/// Converts the type of the result to an LLVM type, pass operands as is,
|
||||
/// preserve attributes.
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
|
||||
operands, *this->getTypeConverter(),
|
||||
rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
/// Basic lowering implementation to rewrite Ops with just one result to the
|
||||
/// LLVM Dialect. This supports higher-dimensional vector types.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
|
||||
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
static_assert(
|
||||
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
|
||||
"expected single result op");
|
||||
return LLVM::detail::vectorOneToOneRewrite(
|
||||
op, TargetOp::getOperationName(), operands, *this->getTypeConverter(),
|
||||
rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
/// Derived class that automatically populates legalization information for
|
||||
/// different LLVM ops.
|
||||
class LLVMConversionTarget : public ConversionTarget {
|
||||
|
|
|
@ -12,38 +12,10 @@
|
|||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class LLVMTypeConverter;
|
||||
class LowerToLLVMOptions;
|
||||
class ModuleOp;
|
||||
template <typename T>
|
||||
class OperationPass;
|
||||
class RewritePatternSet;
|
||||
using OwningRewritePatternList = RewritePatternSet;
|
||||
|
||||
/// Collect a set of patterns to convert memory-related operations from the
|
||||
/// Standard dialect to the LLVM dialect, excluding non-memory-related
|
||||
/// operations and FuncOp.
|
||||
void populateStdToLLVMMemoryConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Collect a set of patterns to convert from the Standard dialect to the LLVM
|
||||
/// dialect, excluding the memory-related operations.
|
||||
void populateStdToLLVMNonMemoryConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
|
||||
/// `emitCWrappers` is set, the pattern will also produce functions
|
||||
/// that pass memref descriptors by pointer-to-structure in addition to the
|
||||
/// default unpacked form.
|
||||
void populateStdToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Collect the patterns to convert from the Standard dialect to LLVM. The
|
||||
/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
|
||||
/// by reference meaning the references have to remain alive during the entire
|
||||
/// pattern lifetime.
|
||||
void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
|
||||
/// stdlib malloc/free is used by default for allocating memrefs allocated with
|
||||
|
|
|
@ -32,6 +32,7 @@ add_mlir_conversion_library(MLIRGPUToGPURuntimeTransforms
|
|||
MLIRAsyncToLLVM
|
||||
MLIRGPUTransforms
|
||||
MLIRIR
|
||||
MLIRLLVMCommonConversion
|
||||
MLIRLLVMIR
|
||||
MLIRPass
|
||||
MLIRSupport
|
||||
|
|
|
@ -17,7 +17,9 @@
|
|||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
add_mlir_conversion_library(MLIRLLVMCommonConversion
|
||||
LoweringOptions.cpp
|
||||
MemRefBuilder.cpp
|
||||
Pattern.cpp
|
||||
StructBuilder.cpp
|
||||
TypeConverter.cpp
|
||||
VectorPattern.cpp
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
|
|
@ -0,0 +1,269 @@
|
|||
//===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConvertToLLVMPattern
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
|
||||
MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit)
|
||||
: ConversionPattern(typeConverter, rootOpName, benefit, context) {}
|
||||
|
||||
LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
|
||||
return static_cast<LLVMTypeConverter *>(
|
||||
ConversionPattern::getTypeConverter());
|
||||
}
|
||||
|
||||
LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
|
||||
return *getTypeConverter()->getDialect();
|
||||
}
|
||||
|
||||
Type ConvertToLLVMPattern::getIndexType() const {
|
||||
return getTypeConverter()->getIndexType();
|
||||
}
|
||||
|
||||
Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
|
||||
return IntegerType::get(&getTypeConverter()->getContext(),
|
||||
getTypeConverter()->getPointerBitwidth(addressSpace));
|
||||
}
|
||||
|
||||
Type ConvertToLLVMPattern::getVoidType() const {
|
||||
return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
|
||||
}
|
||||
|
||||
Type ConvertToLLVMPattern::getVoidPtrType() const {
|
||||
return LLVM::LLVMPointerType::get(
|
||||
IntegerType::get(&getTypeConverter()->getContext(), 8));
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
|
||||
Location loc,
|
||||
Type resultType,
|
||||
int64_t value) {
|
||||
return builder.create<LLVM::ConstantOp>(
|
||||
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::createIndexConstant(
|
||||
ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
|
||||
return createIndexAttrConstant(builder, loc, getIndexType(), value);
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::getStridedElementPtr(
|
||||
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto successStrides = getStridesAndOffset(type, strides, offset);
|
||||
assert(succeeded(successStrides) && "unexpected non-strided memref");
|
||||
(void)successStrides;
|
||||
|
||||
MemRefDescriptor memRefDescriptor(memRefDesc);
|
||||
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
|
||||
|
||||
Value index;
|
||||
if (offset != 0) // Skip if offset is zero.
|
||||
index = MemRefType::isDynamicStrideOrOffset(offset)
|
||||
? memRefDescriptor.offset(rewriter, loc)
|
||||
: createIndexConstant(rewriter, loc, offset);
|
||||
|
||||
for (int i = 0, e = indices.size(); i < e; ++i) {
|
||||
Value increment = indices[i];
|
||||
if (strides[i] != 1) { // Skip if stride is 1.
|
||||
Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
|
||||
? memRefDescriptor.stride(rewriter, loc, i)
|
||||
: createIndexConstant(rewriter, loc, strides[i]);
|
||||
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
|
||||
}
|
||||
index =
|
||||
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
|
||||
}
|
||||
|
||||
Type elementPtrType = memRefDescriptor.getElementPtrType();
|
||||
return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
|
||||
: base;
|
||||
}
|
||||
|
||||
// Check if the MemRefType `type` is supported by the lowering. We currently
|
||||
// only support memrefs with identity maps.
|
||||
bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
|
||||
MemRefType type) const {
|
||||
if (!typeConverter->convertType(type.getElementType()))
|
||||
return false;
|
||||
return type.getAffineMaps().empty() ||
|
||||
llvm::all_of(type.getAffineMaps(),
|
||||
[](AffineMap map) { return map.isIdentity(); });
|
||||
}
|
||||
|
||||
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
|
||||
auto elementType = type.getElementType();
|
||||
auto structElementType = typeConverter->convertType(elementType);
|
||||
return LLVM::LLVMPointerType::get(structElementType,
|
||||
type.getMemorySpaceAsInt());
|
||||
}
|
||||
|
||||
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
|
||||
Location loc, MemRefType memRefType, ValueRange dynamicSizes,
|
||||
ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
|
||||
SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
|
||||
assert(isConvertibleAndHasIdentityMaps(memRefType) &&
|
||||
"layout maps must have been normalized away");
|
||||
assert(count(memRefType.getShape(), ShapedType::kDynamicSize) ==
|
||||
static_cast<ssize_t>(dynamicSizes.size()) &&
|
||||
"dynamicSizes size doesn't match dynamic sizes count in memref shape");
|
||||
|
||||
sizes.reserve(memRefType.getRank());
|
||||
unsigned dynamicIndex = 0;
|
||||
for (int64_t size : memRefType.getShape()) {
|
||||
sizes.push_back(size == ShapedType::kDynamicSize
|
||||
? dynamicSizes[dynamicIndex++]
|
||||
: createIndexConstant(rewriter, loc, size));
|
||||
}
|
||||
|
||||
// Strides: iterate sizes in reverse order and multiply.
|
||||
int64_t stride = 1;
|
||||
Value runningStride = createIndexConstant(rewriter, loc, 1);
|
||||
strides.resize(memRefType.getRank());
|
||||
for (auto i = memRefType.getRank(); i-- > 0;) {
|
||||
strides[i] = runningStride;
|
||||
|
||||
int64_t size = memRefType.getShape()[i];
|
||||
if (size == 0)
|
||||
continue;
|
||||
bool useSizeAsStride = stride == 1;
|
||||
if (size == ShapedType::kDynamicSize)
|
||||
stride = ShapedType::kDynamicSize;
|
||||
if (stride != ShapedType::kDynamicSize)
|
||||
stride *= size;
|
||||
|
||||
if (useSizeAsStride)
|
||||
runningStride = sizes[i];
|
||||
else if (stride == ShapedType::kDynamicSize)
|
||||
runningStride =
|
||||
rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
|
||||
else
|
||||
runningStride = createIndexConstant(rewriter, loc, stride);
|
||||
}
|
||||
|
||||
// Buffer size in bytes.
|
||||
Type elementPtrType = getElementPtrType(memRefType);
|
||||
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
|
||||
Value gepPtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
|
||||
sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::getSizeInBytes(
|
||||
Location loc, Type type, ConversionPatternRewriter &rewriter) const {
|
||||
// Compute the size of an individual element. This emits the MLIR equivalent
|
||||
// of the following sizeof(...) implementation in LLVM IR:
|
||||
// %0 = getelementptr %elementType* null, %indexType 1
|
||||
// %1 = ptrtoint %elementType* %0 to %indexType
|
||||
// which is a common pattern of getting the size of a type in bytes.
|
||||
auto convertedPtrType =
|
||||
LLVM::LLVMPointerType::get(typeConverter->convertType(type));
|
||||
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
|
||||
auto gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, convertedPtrType,
|
||||
ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
|
||||
return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::getNumElements(
|
||||
Location loc, ArrayRef<Value> shape,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Compute the total number of memref elements.
|
||||
Value numElements =
|
||||
shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
|
||||
for (unsigned i = 1, e = shape.size(); i < e; ++i)
|
||||
numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
|
||||
return numElements;
|
||||
}
|
||||
|
||||
/// Creates and populates the memref descriptor struct given all its fields.
|
||||
MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
|
||||
Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
|
||||
ArrayRef<Value> sizes, ArrayRef<Value> strides,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto structType = typeConverter->convertType(memRefType);
|
||||
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
|
||||
|
||||
// Field 1: Allocated pointer, used for malloc/free.
|
||||
memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
|
||||
|
||||
// Field 2: Actual aligned pointer to payload.
|
||||
memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
|
||||
|
||||
// Field 3: Offset in aligned pointer.
|
||||
memRefDescriptor.setOffset(rewriter, loc,
|
||||
createIndexConstant(rewriter, loc, 0));
|
||||
|
||||
// Fields 4: Sizes.
|
||||
for (auto en : llvm::enumerate(sizes))
|
||||
memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
|
||||
|
||||
// Field 5: Strides.
|
||||
for (auto en : llvm::enumerate(strides))
|
||||
memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
|
||||
|
||||
return memRefDescriptor;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Detail methods
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Replaces the given operation "op" with a new operation of type "targetOp"
|
||||
/// and given operands.
|
||||
LogicalResult LLVM::detail::oneToOneRewrite(
|
||||
Operation *op, StringRef targetOp, ValueRange operands,
|
||||
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
|
||||
unsigned numResults = op->getNumResults();
|
||||
|
||||
Type packedType;
|
||||
if (numResults != 0) {
|
||||
packedType = typeConverter.packFunctionResults(op->getResultTypes());
|
||||
if (!packedType)
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Create the operation through state since we don't know its C++ type.
|
||||
OperationState state(op->getLoc(), targetOp);
|
||||
state.addTypes(packedType);
|
||||
state.addOperands(operands);
|
||||
state.addAttributes(op->getAttrs());
|
||||
Operation *newOp = rewriter.createOperation(state);
|
||||
|
||||
// If the operation produced 0 or 1 result, return them immediately.
|
||||
if (numResults == 0)
|
||||
return rewriter.eraseOp(op), success();
|
||||
if (numResults == 1)
|
||||
return rewriter.replaceOp(op, newOp->getResult(0)), success();
|
||||
|
||||
// Otherwise, it had been converted to an operation producing a structure.
|
||||
// Extract individual results from the structure and return them as list.
|
||||
SmallVector<Value, 4> results;
|
||||
results.reserve(numResults);
|
||||
for (unsigned i = 0; i < numResults; ++i) {
|
||||
auto type = typeConverter.convertType(op->getResult(i).getType());
|
||||
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
|
@ -0,0 +1,142 @@
|
|||
//===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// For >1-D vector types, extracts the necessary information to iterate over all
|
||||
// 1-D subvectors in the underlying llrepresentation of the n-D vector
|
||||
// Iterates on the llvm array type until we hit a non-array type (which is
|
||||
// asserted to be an llvm vector type).
|
||||
LLVM::detail::NDVectorTypeInfo
|
||||
LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType,
|
||||
LLVMTypeConverter &converter) {
|
||||
assert(vectorType.getRank() > 1 && "expected >1D vector type");
|
||||
NDVectorTypeInfo info;
|
||||
info.llvmNDVectorTy = converter.convertType(vectorType);
|
||||
if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
|
||||
info.llvmNDVectorTy = nullptr;
|
||||
return info;
|
||||
}
|
||||
info.arraySizes.reserve(vectorType.getRank() - 1);
|
||||
auto llvmTy = info.llvmNDVectorTy;
|
||||
while (llvmTy.isa<LLVM::LLVMArrayType>()) {
|
||||
info.arraySizes.push_back(
|
||||
llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
|
||||
llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
|
||||
}
|
||||
if (!LLVM::isCompatibleVectorType(llvmTy))
|
||||
return info;
|
||||
info.llvm1DVectorTy = llvmTy;
|
||||
return info;
|
||||
}
|
||||
|
||||
// Express `linearIndex` in terms of coordinates of `basis`.
|
||||
// Returns the empty vector when linearIndex is out of the range [0, P] where
|
||||
// P is the product of all the basis coordinates.
|
||||
//
|
||||
// Prerequisites:
|
||||
// Basis is an array of nonnegative integers (signed type inherited from
|
||||
// vector shape type).
|
||||
SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis,
|
||||
unsigned linearIndex) {
|
||||
SmallVector<int64_t, 4> res;
|
||||
res.reserve(basis.size());
|
||||
for (unsigned basisElement : llvm::reverse(basis)) {
|
||||
res.push_back(linearIndex % basisElement);
|
||||
linearIndex = linearIndex / basisElement;
|
||||
}
|
||||
if (linearIndex > 0)
|
||||
return {};
|
||||
std::reverse(res.begin(), res.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
// Iterate of linear index, convert to coords space and insert splatted 1-D
|
||||
// vector in each position.
|
||||
void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
|
||||
OpBuilder &builder,
|
||||
function_ref<void(ArrayAttr)> fun) {
|
||||
unsigned ub = 1;
|
||||
for (auto s : info.arraySizes)
|
||||
ub *= s;
|
||||
for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
|
||||
auto coords = getCoordinates(info.arraySizes, linearIndex);
|
||||
// Linear index is out of bounds, we are done.
|
||||
if (coords.empty())
|
||||
break;
|
||||
assert(coords.size() == info.arraySizes.size());
|
||||
auto position = builder.getI64ArrayAttr(coords);
|
||||
fun(position);
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult LLVM::detail::handleMultidimensionalVectors(
|
||||
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
|
||||
std::function<Value(Type, ValueRange)> createOperand,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
|
||||
|
||||
SmallVector<Type> operand1DVectorTypes;
|
||||
for (Value operand : op->getOperands()) {
|
||||
auto operandNDVectorType = operand.getType().cast<VectorType>();
|
||||
auto operandTypeInfo =
|
||||
extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
|
||||
operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
|
||||
}
|
||||
auto resultTypeInfo =
|
||||
extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
|
||||
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
|
||||
auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
|
||||
auto loc = op->getLoc();
|
||||
Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
|
||||
nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
|
||||
// For this unrolled `position` corresponding to the `linearIndex`^th
|
||||
// element, extract operand vectors
|
||||
SmallVector<Value, 4> extractedOperands;
|
||||
for (auto operand : llvm::enumerate(operands)) {
|
||||
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, operand1DVectorTypes[operand.index()], operand.value(),
|
||||
position));
|
||||
}
|
||||
Value newVal = createOperand(result1DVectorTy, extractedOperands);
|
||||
desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
|
||||
newVal, position);
|
||||
});
|
||||
rewriter.replaceOp(op, desc);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult LLVM::detail::vectorOneToOneRewrite(
|
||||
Operation *op, StringRef targetOp, ValueRange operands,
|
||||
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
|
||||
assert(!operands.empty());
|
||||
|
||||
// Cannot convert ops if their operands are not of LLVM type.
|
||||
if (!llvm::all_of(operands.getTypes(),
|
||||
[](Type t) { return isCompatibleType(t); }))
|
||||
return failure();
|
||||
|
||||
auto llvmNDVectorTy = operands[0].getType();
|
||||
if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
|
||||
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
|
||||
|
||||
auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
|
||||
ValueRange operands) {
|
||||
OperationState state(op->getLoc(), targetOp);
|
||||
state.addTypes(llvm1DVectorTy);
|
||||
state.addOperands(operands);
|
||||
state.addAttributes(op->getAttrs());
|
||||
return rewriter.createOperation(state)->getResult(0);
|
||||
};
|
||||
|
||||
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
|
||||
rewriter);
|
||||
}
|
|
@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIROpenMPToLLVM
|
|||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRLLVMCommonConversion
|
||||
MLIRLLVMIR
|
||||
MLIROpenMP
|
||||
MLIRStandardToLLVM
|
||||
|
|
|
@ -9,7 +9,9 @@
|
|||
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Analysis/DataLayoutAnalysis.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||
|
@ -46,214 +48,6 @@ using namespace mlir;
|
|||
|
||||
#define PASS_NAME "convert-std-to-llvm"
|
||||
|
||||
ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
|
||||
MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit)
|
||||
: ConversionPattern(typeConverter, rootOpName, benefit, context) {}
|
||||
|
||||
|
||||
LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
|
||||
return static_cast<LLVMTypeConverter *>(
|
||||
ConversionPattern::getTypeConverter());
|
||||
}
|
||||
|
||||
LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
|
||||
return *getTypeConverter()->getDialect();
|
||||
}
|
||||
|
||||
Type ConvertToLLVMPattern::getIndexType() const {
|
||||
return getTypeConverter()->getIndexType();
|
||||
}
|
||||
|
||||
Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
|
||||
return IntegerType::get(&getTypeConverter()->getContext(),
|
||||
getTypeConverter()->getPointerBitwidth(addressSpace));
|
||||
}
|
||||
|
||||
Type ConvertToLLVMPattern::getVoidType() const {
|
||||
return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
|
||||
}
|
||||
|
||||
Type ConvertToLLVMPattern::getVoidPtrType() const {
|
||||
return LLVM::LLVMPointerType::get(
|
||||
IntegerType::get(&getTypeConverter()->getContext(), 8));
|
||||
}
|
||||
|
||||
// Creates a constant Op producing a value of `resultType` from an index-typed
|
||||
// integer attribute.
|
||||
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
|
||||
Type resultType, int64_t value) {
|
||||
return builder.create<LLVM::ConstantOp>(
|
||||
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::createIndexConstant(
|
||||
ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
|
||||
return createIndexAttrConstant(builder, loc, getIndexType(), value);
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::getStridedElementPtr(
|
||||
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto successStrides = getStridesAndOffset(type, strides, offset);
|
||||
assert(succeeded(successStrides) && "unexpected non-strided memref");
|
||||
(void)successStrides;
|
||||
|
||||
MemRefDescriptor memRefDescriptor(memRefDesc);
|
||||
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
|
||||
|
||||
Value index;
|
||||
if (offset != 0) // Skip if offset is zero.
|
||||
index = MemRefType::isDynamicStrideOrOffset(offset)
|
||||
? memRefDescriptor.offset(rewriter, loc)
|
||||
: createIndexConstant(rewriter, loc, offset);
|
||||
|
||||
for (int i = 0, e = indices.size(); i < e; ++i) {
|
||||
Value increment = indices[i];
|
||||
if (strides[i] != 1) { // Skip if stride is 1.
|
||||
Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
|
||||
? memRefDescriptor.stride(rewriter, loc, i)
|
||||
: createIndexConstant(rewriter, loc, strides[i]);
|
||||
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
|
||||
}
|
||||
index =
|
||||
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
|
||||
}
|
||||
|
||||
Type elementPtrType = memRefDescriptor.getElementPtrType();
|
||||
return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
|
||||
: base;
|
||||
}
|
||||
|
||||
// Check if the MemRefType `type` is supported by the lowering. We currently
|
||||
// only support memrefs with identity maps.
|
||||
bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
|
||||
MemRefType type) const {
|
||||
if (!typeConverter->convertType(type.getElementType()))
|
||||
return false;
|
||||
return type.getAffineMaps().empty() ||
|
||||
llvm::all_of(type.getAffineMaps(),
|
||||
[](AffineMap map) { return map.isIdentity(); });
|
||||
}
|
||||
|
||||
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
|
||||
auto elementType = type.getElementType();
|
||||
auto structElementType = typeConverter->convertType(elementType);
|
||||
return LLVM::LLVMPointerType::get(structElementType,
|
||||
type.getMemorySpaceAsInt());
|
||||
}
|
||||
|
||||
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
|
||||
Location loc, MemRefType memRefType, ValueRange dynamicSizes,
|
||||
ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
|
||||
SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
|
||||
assert(isConvertibleAndHasIdentityMaps(memRefType) &&
|
||||
"layout maps must have been normalized away");
|
||||
assert(count(memRefType.getShape(), ShapedType::kDynamicSize) ==
|
||||
static_cast<ssize_t>(dynamicSizes.size()) &&
|
||||
"dynamicSizes size doesn't match dynamic sizes count in memref shape");
|
||||
|
||||
sizes.reserve(memRefType.getRank());
|
||||
unsigned dynamicIndex = 0;
|
||||
for (int64_t size : memRefType.getShape()) {
|
||||
sizes.push_back(size == ShapedType::kDynamicSize
|
||||
? dynamicSizes[dynamicIndex++]
|
||||
: createIndexConstant(rewriter, loc, size));
|
||||
}
|
||||
|
||||
// Strides: iterate sizes in reverse order and multiply.
|
||||
int64_t stride = 1;
|
||||
Value runningStride = createIndexConstant(rewriter, loc, 1);
|
||||
strides.resize(memRefType.getRank());
|
||||
for (auto i = memRefType.getRank(); i-- > 0;) {
|
||||
strides[i] = runningStride;
|
||||
|
||||
int64_t size = memRefType.getShape()[i];
|
||||
if (size == 0)
|
||||
continue;
|
||||
bool useSizeAsStride = stride == 1;
|
||||
if (size == ShapedType::kDynamicSize)
|
||||
stride = ShapedType::kDynamicSize;
|
||||
if (stride != ShapedType::kDynamicSize)
|
||||
stride *= size;
|
||||
|
||||
if (useSizeAsStride)
|
||||
runningStride = sizes[i];
|
||||
else if (stride == ShapedType::kDynamicSize)
|
||||
runningStride =
|
||||
rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
|
||||
else
|
||||
runningStride = createIndexConstant(rewriter, loc, stride);
|
||||
}
|
||||
|
||||
// Buffer size in bytes.
|
||||
Type elementPtrType = getElementPtrType(memRefType);
|
||||
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
|
||||
Value gepPtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
|
||||
sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::getSizeInBytes(
|
||||
Location loc, Type type, ConversionPatternRewriter &rewriter) const {
|
||||
// Compute the size of an individual element. This emits the MLIR equivalent
|
||||
// of the following sizeof(...) implementation in LLVM IR:
|
||||
// %0 = getelementptr %elementType* null, %indexType 1
|
||||
// %1 = ptrtoint %elementType* %0 to %indexType
|
||||
// which is a common pattern of getting the size of a type in bytes.
|
||||
auto convertedPtrType =
|
||||
LLVM::LLVMPointerType::get(typeConverter->convertType(type));
|
||||
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
|
||||
auto gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, convertedPtrType,
|
||||
ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
|
||||
return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::getNumElements(
|
||||
Location loc, ArrayRef<Value> shape,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Compute the total number of memref elements.
|
||||
Value numElements =
|
||||
shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
|
||||
for (unsigned i = 1, e = shape.size(); i < e; ++i)
|
||||
numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
|
||||
return numElements;
|
||||
}
|
||||
|
||||
/// Creates and populates the memref descriptor struct given all its fields.
|
||||
MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
|
||||
Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
|
||||
ArrayRef<Value> sizes, ArrayRef<Value> strides,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto structType = typeConverter->convertType(memRefType);
|
||||
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
|
||||
|
||||
// Field 1: Allocated pointer, used for malloc/free.
|
||||
memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
|
||||
|
||||
// Field 2: Actual aligned pointer to payload.
|
||||
memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
|
||||
|
||||
// Field 3: Offset in aligned pointer.
|
||||
memRefDescriptor.setOffset(rewriter, loc,
|
||||
createIndexConstant(rewriter, loc, 0));
|
||||
|
||||
// Fields 4: Sizes.
|
||||
for (auto en : llvm::enumerate(sizes))
|
||||
memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
|
||||
|
||||
// Field 5: Strides.
|
||||
for (auto en : llvm::enumerate(strides))
|
||||
memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
|
||||
|
||||
return memRefDescriptor;
|
||||
}
|
||||
|
||||
/// Only retain those attributes that are not constructed by
|
||||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||
/// attributes.
|
||||
|
@ -572,190 +366,6 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
|||
}
|
||||
};
|
||||
|
||||
//////////////// Support for Lowering operations on n-D vectors ////////////////
|
||||
// Helper struct to "unroll" operations on n-D vectors in terms of operations on
|
||||
// 1-D LLVM vectors.
|
||||
struct NDVectorTypeInfo {
|
||||
// LLVM array struct which encodes n-D vectors.
|
||||
Type llvmNDVectorTy;
|
||||
// LLVM vector type which encodes the inner 1-D vector type.
|
||||
Type llvm1DVectorTy;
|
||||
// Multiplicity of llvmNDVectorTy to llvm1DVectorTy.
|
||||
SmallVector<int64_t, 4> arraySizes;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// For >1-D vector types, extracts the necessary information to iterate over all
|
||||
// 1-D subvectors in the underlying llrepresentation of the n-D vector
|
||||
// Iterates on the llvm array type until we hit a non-array type (which is
|
||||
// asserted to be an llvm vector type).
|
||||
static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
|
||||
LLVMTypeConverter &converter) {
|
||||
assert(vectorType.getRank() > 1 && "expected >1D vector type");
|
||||
NDVectorTypeInfo info;
|
||||
info.llvmNDVectorTy = converter.convertType(vectorType);
|
||||
if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
|
||||
info.llvmNDVectorTy = nullptr;
|
||||
return info;
|
||||
}
|
||||
info.arraySizes.reserve(vectorType.getRank() - 1);
|
||||
auto llvmTy = info.llvmNDVectorTy;
|
||||
while (llvmTy.isa<LLVM::LLVMArrayType>()) {
|
||||
info.arraySizes.push_back(
|
||||
llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
|
||||
llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
|
||||
}
|
||||
if (!LLVM::isCompatibleVectorType(llvmTy))
|
||||
return info;
|
||||
info.llvm1DVectorTy = llvmTy;
|
||||
return info;
|
||||
}
|
||||
|
||||
// Express `linearIndex` in terms of coordinates of `basis`.
|
||||
// Returns the empty vector when linearIndex is out of the range [0, P] where
|
||||
// P is the product of all the basis coordinates.
|
||||
//
|
||||
// Prerequisites:
|
||||
// Basis is an array of nonnegative integers (signed type inherited from
|
||||
// vector shape type).
|
||||
static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
|
||||
unsigned linearIndex) {
|
||||
SmallVector<int64_t, 4> res;
|
||||
res.reserve(basis.size());
|
||||
for (unsigned basisElement : llvm::reverse(basis)) {
|
||||
res.push_back(linearIndex % basisElement);
|
||||
linearIndex = linearIndex / basisElement;
|
||||
}
|
||||
if (linearIndex > 0)
|
||||
return {};
|
||||
std::reverse(res.begin(), res.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
// Iterate of linear index, convert to coords space and insert splatted 1-D
|
||||
// vector in each position.
|
||||
template <typename Lambda>
|
||||
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
|
||||
Lambda fun) {
|
||||
unsigned ub = 1;
|
||||
for (auto s : info.arraySizes)
|
||||
ub *= s;
|
||||
for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
|
||||
auto coords = getCoordinates(info.arraySizes, linearIndex);
|
||||
// Linear index is out of bounds, we are done.
|
||||
if (coords.empty())
|
||||
break;
|
||||
assert(coords.size() == info.arraySizes.size());
|
||||
auto position = builder.getI64ArrayAttr(coords);
|
||||
fun(position);
|
||||
}
|
||||
}
|
||||
////////////// End Support for Lowering operations on n-D vectors //////////////
|
||||
|
||||
/// Replaces the given operation "op" with a new operation of type "targetOp"
|
||||
/// and given operands.
|
||||
LogicalResult LLVM::detail::oneToOneRewrite(
|
||||
Operation *op, StringRef targetOp, ValueRange operands,
|
||||
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
|
||||
unsigned numResults = op->getNumResults();
|
||||
|
||||
Type packedType;
|
||||
if (numResults != 0) {
|
||||
packedType = typeConverter.packFunctionResults(op->getResultTypes());
|
||||
if (!packedType)
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Create the operation through state since we don't know its C++ type.
|
||||
OperationState state(op->getLoc(), targetOp);
|
||||
state.addTypes(packedType);
|
||||
state.addOperands(operands);
|
||||
state.addAttributes(op->getAttrs());
|
||||
Operation *newOp = rewriter.createOperation(state);
|
||||
|
||||
// If the operation produced 0 or 1 result, return them immediately.
|
||||
if (numResults == 0)
|
||||
return rewriter.eraseOp(op), success();
|
||||
if (numResults == 1)
|
||||
return rewriter.replaceOp(op, newOp->getResult(0)), success();
|
||||
|
||||
// Otherwise, it had been converted to an operation producing a structure.
|
||||
// Extract individual results from the structure and return them as list.
|
||||
SmallVector<Value, 4> results;
|
||||
results.reserve(numResults);
|
||||
for (unsigned i = 0; i < numResults; ++i) {
|
||||
auto type = typeConverter.convertType(op->getResult(i).getType());
|
||||
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult handleMultidimensionalVectors(
|
||||
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
|
||||
std::function<Value(Type, ValueRange)> createOperand,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
|
||||
|
||||
SmallVector<Type> operand1DVectorTypes;
|
||||
for (Value operand : op->getOperands()) {
|
||||
auto operandNDVectorType = operand.getType().cast<VectorType>();
|
||||
auto operandTypeInfo =
|
||||
extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
|
||||
operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
|
||||
}
|
||||
auto resultTypeInfo =
|
||||
extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
|
||||
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
|
||||
auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
|
||||
auto loc = op->getLoc();
|
||||
Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
|
||||
nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
|
||||
// For this unrolled `position` corresponding to the `linearIndex`^th
|
||||
// element, extract operand vectors
|
||||
SmallVector<Value, 4> extractedOperands;
|
||||
for (auto operand : llvm::enumerate(operands)) {
|
||||
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, operand1DVectorTypes[operand.index()], operand.value(),
|
||||
position));
|
||||
}
|
||||
Value newVal = createOperand(result1DVectorTy, extractedOperands);
|
||||
desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
|
||||
newVal, position);
|
||||
});
|
||||
rewriter.replaceOp(op, desc);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult LLVM::detail::vectorOneToOneRewrite(
|
||||
Operation *op, StringRef targetOp, ValueRange operands,
|
||||
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
|
||||
assert(!operands.empty());
|
||||
|
||||
// Cannot convert ops if their operands are not of LLVM type.
|
||||
if (!llvm::all_of(operands.getTypes(),
|
||||
[](Type t) { return isCompatibleType(t); }))
|
||||
return failure();
|
||||
|
||||
auto llvmNDVectorTy = operands[0].getType();
|
||||
if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
|
||||
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
|
||||
|
||||
auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
|
||||
ValueRange operands) {
|
||||
OperationState state(op->getLoc(), targetOp);
|
||||
state.addTypes(llvm1DVectorTy);
|
||||
state.addOperands(operands);
|
||||
state.addAttributes(op->getAttrs());
|
||||
return rewriter.createOperation(state)->getResult(0);
|
||||
};
|
||||
|
||||
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
|
||||
rewriter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Straightforward lowerings.
|
||||
using AbsFOpLowering = VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp>;
|
||||
using AddFOpLowering = VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp>;
|
||||
|
@ -1427,7 +1037,7 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
|
|||
if (!vectorType)
|
||||
return rewriter.notifyMatchFailure(op, "expected vector result type");
|
||||
|
||||
return handleMultidimensionalVectors(
|
||||
return LLVM::detail::handleMultidimensionalVectors(
|
||||
op.getOperation(), operands, *getTypeConverter(),
|
||||
[&](Type llvm1DVectorTy, ValueRange operands) {
|
||||
auto splatAttr = SplatElementsAttr::get(
|
||||
|
@ -1482,7 +1092,7 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
|
|||
if (!vectorType)
|
||||
return rewriter.notifyMatchFailure(op, "expected vector result type");
|
||||
|
||||
return handleMultidimensionalVectors(
|
||||
return LLVM::detail::handleMultidimensionalVectors(
|
||||
op.getOperation(), operands, *getTypeConverter(),
|
||||
[&](Type llvm1DVectorTy, ValueRange operands) {
|
||||
auto splatAttr = SplatElementsAttr::get(
|
||||
|
@ -1536,7 +1146,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
|
|||
if (!vectorType)
|
||||
return failure();
|
||||
|
||||
return handleMultidimensionalVectors(
|
||||
return LLVM::detail::handleMultidimensionalVectors(
|
||||
op.getOperation(), operands, *getTypeConverter(),
|
||||
[&](Type llvm1DVectorTy, ValueRange operands) {
|
||||
auto splatAttr = SplatElementsAttr::get(
|
||||
|
@ -2244,7 +1854,7 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
|
|||
if (!vectorType)
|
||||
return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type");
|
||||
|
||||
return handleMultidimensionalVectors(
|
||||
return LLVM::detail::handleMultidimensionalVectors(
|
||||
cmpiOp.getOperation(), operands, *getTypeConverter(),
|
||||
[&](Type llvm1DVectorTy, ValueRange operands) {
|
||||
CmpIOpAdaptor transformed(operands);
|
||||
|
@ -2282,7 +1892,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
|
|||
if (!vectorType)
|
||||
return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type");
|
||||
|
||||
return handleMultidimensionalVectors(
|
||||
return LLVM::detail::handleMultidimensionalVectors(
|
||||
cmpfOp.getOperation(), operands, *getTypeConverter(),
|
||||
[&](Type llvm1DVectorTy, ValueRange operands) {
|
||||
CmpFOpAdaptor transformed(operands);
|
||||
|
@ -2445,7 +2055,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
|
|||
// First insert it into an undef vector so we can shuffle it.
|
||||
auto loc = splatOp.getLoc();
|
||||
auto vectorTypeInfo =
|
||||
extractNDVectorTypeInfo(resultType, *getTypeConverter());
|
||||
LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
|
||||
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
|
||||
auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
|
||||
if (!llvmNDVectorTy || !llvm1DVectorTy)
|
||||
|
|
Loading…
Reference in New Issue