forked from OSchip/llvm-project
[mlir][spirv] Support size-1 vector inserts during conversion
Differential Revision: https://reviews.llvm.org/D115517
This commit is contained in:
parent
11754a4dbb
commit
4710750854
|
@ -157,6 +157,13 @@ struct VectorInsertOpConvert final
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
|
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
// Special case for inserting scalar values into size-1 vectors.
|
||||||
|
if (insertOp.getSourceType().isIntOrFloat() &&
|
||||||
|
insertOp.getDestVectorType().getNumElements() == 1) {
|
||||||
|
rewriter.replaceOp(insertOp, adaptor.source());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
if (insertOp.getSourceType().isa<VectorType>() ||
|
if (insertOp.getSourceType().isa<VectorType>() ||
|
||||||
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
|
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -209,20 +216,23 @@ struct VectorInsertStridedSliceOpConvert final
|
||||||
Value srcVector = adaptor.getOperands().front();
|
Value srcVector = adaptor.getOperands().front();
|
||||||
Value dstVector = adaptor.getOperands().back();
|
Value dstVector = adaptor.getOperands().back();
|
||||||
|
|
||||||
// Insert scalar values not supported yet.
|
|
||||||
if (srcVector.getType().isa<spirv::ScalarType>() ||
|
|
||||||
dstVector.getType().isa<spirv::ScalarType>())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
uint64_t stride = getFirstIntValue(insertOp.strides());
|
uint64_t stride = getFirstIntValue(insertOp.strides());
|
||||||
if (stride != 1)
|
if (stride != 1)
|
||||||
return failure();
|
return failure();
|
||||||
|
uint64_t offset = getFirstIntValue(insertOp.offsets());
|
||||||
|
|
||||||
|
if (srcVector.getType().isa<spirv::ScalarType>()) {
|
||||||
|
assert(!dstVector.getType().isa<spirv::ScalarType>());
|
||||||
|
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
|
||||||
|
insertOp, dstVector.getType(), srcVector, dstVector,
|
||||||
|
rewriter.getI32ArrayAttr(offset));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
uint64_t totalSize =
|
uint64_t totalSize =
|
||||||
dstVector.getType().cast<VectorType>().getNumElements();
|
dstVector.getType().cast<VectorType>().getNumElements();
|
||||||
uint64_t insertSize =
|
uint64_t insertSize =
|
||||||
srcVector.getType().cast<VectorType>().getNumElements();
|
srcVector.getType().cast<VectorType>().getNumElements();
|
||||||
uint64_t offset = getFirstIntValue(insertOp.offsets());
|
|
||||||
|
|
||||||
SmallVector<int32_t, 2> indices(totalSize);
|
SmallVector<int32_t, 2> indices(totalSize);
|
||||||
std::iota(indices.begin(), indices.end(), 0);
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
|
|
@ -61,6 +61,17 @@ func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @insert_size1_vector
|
||||||
|
// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
|
||||||
|
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
|
||||||
|
// CHECK: return %[[R]]
|
||||||
|
func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> {
|
||||||
|
%1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32>
|
||||||
|
return %1 : vector<1xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @extract_element
|
// CHECK-LABEL: @extract_element
|
||||||
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
|
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
|
||||||
// CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
|
// CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
|
||||||
|
@ -139,6 +150,17 @@ func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @insert_size1_vector
|
||||||
|
// CHECK-SAME: %[[SUB:.*]]: vector<1xf32>, %[[FULL:.*]]: vector<3xf32>
|
||||||
|
// CHECK: %[[S:.+]] = builtin.unrealized_conversion_cast %[[SUB]]
|
||||||
|
// CHECK: spv.CompositeInsert %[[S]], %[[FULL]][2 : i32] : f32 into vector<3xf32>
|
||||||
|
func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> {
|
||||||
|
%1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
|
||||||
|
return %1 : vector<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @fma
|
// CHECK-LABEL: @fma
|
||||||
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
|
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
|
||||||
// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
|
// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
|
||||||
|
|
Loading…
Reference in New Issue