[mlir][spirv] Support size-1 vector inserts during conversion

Differential Revision: https://reviews.llvm.org/D115517
This commit is contained in:
Lei Zhang 2022-01-07 17:19:41 -05:00
parent 11754a4dbb
commit 4710750854
2 changed files with 38 additions and 6 deletions

View File

@ -157,6 +157,13 @@ struct VectorInsertOpConvert final
LogicalResult LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
// Special case for inserting scalar values into size-1 vectors.
if (insertOp.getSourceType().isIntOrFloat() &&
insertOp.getDestVectorType().getNumElements() == 1) {
rewriter.replaceOp(insertOp, adaptor.source());
return success();
}
if (insertOp.getSourceType().isa<VectorType>() || if (insertOp.getSourceType().isa<VectorType>() ||
!spirv::CompositeType::isValid(insertOp.getDestVectorType())) !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
return failure(); return failure();
@ -209,20 +216,23 @@ struct VectorInsertStridedSliceOpConvert final
Value srcVector = adaptor.getOperands().front(); Value srcVector = adaptor.getOperands().front();
Value dstVector = adaptor.getOperands().back(); Value dstVector = adaptor.getOperands().back();
// Insert scalar values not supported yet.
if (srcVector.getType().isa<spirv::ScalarType>() ||
dstVector.getType().isa<spirv::ScalarType>())
return failure();
uint64_t stride = getFirstIntValue(insertOp.strides()); uint64_t stride = getFirstIntValue(insertOp.strides());
if (stride != 1) if (stride != 1)
return failure(); return failure();
uint64_t offset = getFirstIntValue(insertOp.offsets());
if (srcVector.getType().isa<spirv::ScalarType>()) {
assert(!dstVector.getType().isa<spirv::ScalarType>());
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, dstVector.getType(), srcVector, dstVector,
rewriter.getI32ArrayAttr(offset));
return success();
}
uint64_t totalSize = uint64_t totalSize =
dstVector.getType().cast<VectorType>().getNumElements(); dstVector.getType().cast<VectorType>().getNumElements();
uint64_t insertSize = uint64_t insertSize =
srcVector.getType().cast<VectorType>().getNumElements(); srcVector.getType().cast<VectorType>().getNumElements();
uint64_t offset = getFirstIntValue(insertOp.offsets());
SmallVector<int32_t, 2> indices(totalSize); SmallVector<int32_t, 2> indices(totalSize);
std::iota(indices.begin(), indices.end(), 0); std::iota(indices.begin(), indices.end(), 0);

View File

@ -61,6 +61,17 @@ func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// ----- // -----
// CHECK-LABEL: @insert_size1_vector
// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
// CHECK: return %[[R]]
func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> {
%1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32>
return %1 : vector<1xf32>
}
// -----
// CHECK-LABEL: @extract_element // CHECK-LABEL: @extract_element
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
// CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 // CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
@ -139,6 +150,17 @@ func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector
// ----- // -----
// CHECK-LABEL: @insert_size1_vector
// CHECK-SAME: %[[SUB:.*]]: vector<1xf32>, %[[FULL:.*]]: vector<3xf32>
// CHECK: %[[S:.+]] = builtin.unrealized_conversion_cast %[[SUB]]
// CHECK: spv.CompositeInsert %[[S]], %[[FULL]][2 : i32] : f32 into vector<3xf32>
func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> {
%1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
return %1 : vector<3xf32>
}
// -----
// CHECK-LABEL: @fma // CHECK-LABEL: @fma
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> // CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32> // CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>