[mlir][spirv] Add more vector conversion patterns

This patch introduces a few more straightforward patterns
to convert vector ops operating on 1-4 element vectors
to their corresponding SPIR-V counterparts.

This patch also enables converting vector<1xT> to T.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D96042
This commit is contained in:
Lei Zhang 2021-02-05 09:03:48 -05:00
parent 4a64d8fe39
commit 9f622b3d5d
5 changed files with 180 additions and 19 deletions

View File

@ -19,10 +19,40 @@
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include <numeric>
using namespace mlir;
/// Gets the first integer value from `attr`, assuming it is an integer array
/// attribute.
static uint64_t getFirstIntValue(ArrayAttr attr) {
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
};
namespace {
struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
if (!dstType)
return failure();
vector::BitCastOp::Adaptor adaptor(operands);
if (dstType == adaptor.source().getType())
rewriter.replaceOp(bitcastOp, adaptor.source());
else
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
adaptor.source());
return success();
}
};
struct VectorBroadcastConvert final
: public OpConversionPattern<vector::BroadcastOp> {
using OpConversionPattern::OpConversionPattern;
@ -49,17 +79,58 @@ struct VectorExtractOpConvert final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (extractOp.getType().isa<VectorType>() ||
!spirv::CompositeType::isValid(extractOp.getVectorType()))
// Only support extracting a scalar value now.
VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
if (resultVectorType && resultVectorType.getNumElements() > 1)
return failure();
auto dstType = getTypeConverter()->convertType(extractOp.getType());
if (!dstType)
return failure();
vector::ExtractOp::Adaptor adaptor(operands);
int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
int32_t id = getFirstIntValue(extractOp.position());
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
extractOp, adaptor.vector(), id);
return success();
}
};
struct VectorExtractStridedSliceOpConvert final
: public OpConversionPattern<vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dstType = getTypeConverter()->convertType(extractOp.getType());
if (!dstType)
return failure();
// Extract vector<1xT> not supported yet.
if (dstType.isa<spirv::ScalarType>())
return failure();
uint64_t offset = getFirstIntValue(extractOp.offsets());
uint64_t size = getFirstIntValue(extractOp.sizes());
uint64_t stride = getFirstIntValue(extractOp.strides());
if (stride != 1)
return failure();
Value srcVector = operands.front();
SmallVector<int32_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), offset);
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
extractOp, dstType, srcVector, srcVector,
rewriter.getI32ArrayAttr(indices));
return success();
}
};
struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
using OpConversionPattern::OpConversionPattern;
@ -86,7 +157,7 @@ struct VectorInsertOpConvert final
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
return failure();
vector::InsertOp::Adaptor adaptor(operands);
int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
int32_t id = getFirstIntValue(insertOp.position());
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, adaptor.source(), adaptor.dest(), id);
return success();
@ -129,13 +200,53 @@ struct VectorInsertElementOpConvert final
}
};
struct VectorInsertStridedSliceOpConvert final
: public OpConversionPattern<vector::InsertStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::InsertStridedSliceOp insertOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Value srcVector = operands.front();
Value dstVector = operands.back();
// Insert scalar values not supported yet.
if (srcVector.getType().isa<spirv::ScalarType>() ||
dstVector.getType().isa<spirv::ScalarType>())
return failure();
uint64_t stride = getFirstIntValue(insertOp.strides());
if (stride != 1)
return failure();
uint64_t totalSize =
dstVector.getType().cast<VectorType>().getNumElements();
uint64_t insertSize =
srcVector.getType().cast<VectorType>().getNumElements();
uint64_t offset = getFirstIntValue(insertOp.offsets());
SmallVector<int32_t, 2> indices(totalSize);
std::iota(indices.begin(), indices.end(), 0);
std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
totalSize);
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
insertOp, dstVector.getType(), dstVector, srcVector,
rewriter.getI32ArrayAttr(indices));
return success();
}
};
} // namespace
void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
patterns.insert<VectorBroadcastConvert, VectorExtractElementOpConvert,
VectorExtractOpConvert, VectorFmaOpConvert,
VectorInsertOpConvert, VectorInsertElementOpConvert>(
typeConverter, context);
patterns.insert<VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
VectorInsertElementOpConvert, VectorInsertOpConvert,
VectorInsertStridedSliceOpConvert>(typeConverter, context);
}

View File

@ -269,12 +269,13 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
static Optional<Type>
convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
if (type.getRank() == 1 && type.getNumElements() == 1)
return type.getElementType();
if (!spirv::CompositeType::isValid(type)) {
// TODO: One-element vector types can be translated into scalar
// types. Vector types with more than four elements can be translated into
// TODO: Vector types with more than four elements can be translated into
// array types.
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: 1- and > 4-element unimplemented\n");
LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
return llvm::None;
}

View File

@ -117,9 +117,9 @@ func @float_vector234(%arg0: vector<2xf16>, %arg1: vector<3xf64>) {
return
}
// CHECK-LABEL: @unsupported_1elem_vector
func @unsupported_1elem_vector(%arg0: vector<1xi32>) {
// CHECK: addi
// CHECK-LABEL: @one_elem_vector
func @one_elem_vector(%arg0: vector<1xi32>) {
// CHECK: spv.IAdd %{{.+}}, %{{.+}}: i32
%0 = addi %arg0, %arg0: vector<1xi32>
return
}

View File

@ -203,18 +203,19 @@ func @float_vector(
%arg1: vector<3xf64>
) { return }
// CHECK-LABEL: spv.func @one_element_vector
// CHECK-SAME: %{{.+}}: i32
func @one_element_vector(%arg0: vector<1xi32>) { return }
} // end module
// -----
// Check that 1- or > 4-element vectors are not supported.
// Check that > 4-element vectors are not supported.
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
// CHECK-NOT: spv.func @one_element_vector
func @one_element_vector(%arg0: vector<1xi32>) { return }
// CHECK-NOT: spv.func @large_vector
func @large_vector(%arg0: vector<1024xi32>) { return }

View File

@ -1,5 +1,21 @@
// RUN: mlir-opt -split-input-file -convert-vector-to-spirv -verify-diagnostics %s -o - | FileCheck %s
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {
// CHECK-LABEL: func @bitcast
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf16>
// CHECK: %{{.+}} = spv.Bitcast %[[ARG0]] : vector<2xf32> to vector<4xf16>
// CHECK: %{{.+}} = spv.Bitcast %[[ARG1]] : vector<2xf16> to f32
func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) {
%0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16>
%1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32>
spv.Return
}
} // end module
// -----
// CHECK-LABEL: broadcast
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
@ -12,6 +28,18 @@ func @broadcast(%arg0 : f32) {
// -----
// CHECK-LABEL: func @extract
// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
func @extract(%arg0 : vector<2xf32>) {
%0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32>
%1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32
spv.Return
}
// -----
// CHECK-LABEL: extract_insert
// CHECK-SAME: %[[V:.*]]: vector<4xf32>
// CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
@ -42,6 +70,16 @@ func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
// -----
// CHECK-LABEL: func @extract_strided_slice
// CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
// CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
func @extract_strided_slice(%arg0: vector<4xf32>) {
%0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
spv.Return
}
// -----
// CHECK-LABEL: insert_element
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
// CHECK: spv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
@ -60,6 +98,16 @@ func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) {
// -----
// CHECK-LABEL: func @insert_strided_slice
// CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
// CHECK: %{{.+}} = spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) {
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
spv.Return
}
// -----
// CHECK-LABEL: func @fma
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>