forked from OSchip/llvm-project
[mlir][spirv] Support size-1 vector inserts during conversion
Differential Revision: https://reviews.llvm.org/D115517
This commit is contained in:
parent
11754a4dbb
commit
4710750854
|
@ -157,6 +157,13 @@ struct VectorInsertOpConvert final
|
|||
LogicalResult
|
||||
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Special case for inserting scalar values into size-1 vectors.
|
||||
if (insertOp.getSourceType().isIntOrFloat() &&
|
||||
insertOp.getDestVectorType().getNumElements() == 1) {
|
||||
rewriter.replaceOp(insertOp, adaptor.source());
|
||||
return success();
|
||||
}
|
||||
|
||||
if (insertOp.getSourceType().isa<VectorType>() ||
|
||||
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
|
||||
return failure();
|
||||
|
@ -209,20 +216,23 @@ struct VectorInsertStridedSliceOpConvert final
|
|||
Value srcVector = adaptor.getOperands().front();
|
||||
Value dstVector = adaptor.getOperands().back();
|
||||
|
||||
// Insert scalar values not supported yet.
|
||||
if (srcVector.getType().isa<spirv::ScalarType>() ||
|
||||
dstVector.getType().isa<spirv::ScalarType>())
|
||||
return failure();
|
||||
|
||||
uint64_t stride = getFirstIntValue(insertOp.strides());
|
||||
if (stride != 1)
|
||||
return failure();
|
||||
uint64_t offset = getFirstIntValue(insertOp.offsets());
|
||||
|
||||
if (srcVector.getType().isa<spirv::ScalarType>()) {
|
||||
assert(!dstVector.getType().isa<spirv::ScalarType>());
|
||||
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
|
||||
insertOp, dstVector.getType(), srcVector, dstVector,
|
||||
rewriter.getI32ArrayAttr(offset));
|
||||
return success();
|
||||
}
|
||||
|
||||
uint64_t totalSize =
|
||||
dstVector.getType().cast<VectorType>().getNumElements();
|
||||
uint64_t insertSize =
|
||||
srcVector.getType().cast<VectorType>().getNumElements();
|
||||
uint64_t offset = getFirstIntValue(insertOp.offsets());
|
||||
|
||||
SmallVector<int32_t, 2> indices(totalSize);
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
|
|
|
@ -61,6 +61,17 @@ func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @insert_size1_vector
|
||||
// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
|
||||
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
|
||||
// CHECK: return %[[R]]
|
||||
func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> {
|
||||
%1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32>
|
||||
return %1 : vector<1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @extract_element
|
||||
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
|
||||
// CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
|
||||
|
@ -139,6 +150,17 @@ func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @insert_size1_vector
|
||||
// CHECK-SAME: %[[SUB:.*]]: vector<1xf32>, %[[FULL:.*]]: vector<3xf32>
|
||||
// CHECK: %[[S:.+]] = builtin.unrealized_conversion_cast %[[SUB]]
|
||||
// CHECK: spv.CompositeInsert %[[S]], %[[FULL]][2 : i32] : f32 into vector<3xf32>
|
||||
func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> {
|
||||
%1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
|
||||
return %1 : vector<3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fma
|
||||
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
|
||||
// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
|
||||
|
|
Loading…
Reference in New Issue