forked from OSchip/llvm-project
[mlir][spirv] Add VectorInsertDynamicOp and vector.insertelement lowering
VectorInsertDynamicOp in SPIRV dialect conversion from vector.insertelement to spirv VectorInsertDynamicOp Differential Revision: https://reviews.llvm.org/D90927
This commit is contained in:
parent
539ce1d288
commit
3035e676a3
|
@ -3177,6 +3177,7 @@ def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
|
|||
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
|
||||
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
|
||||
def SPV_OC_OpVectorExtractDynamic : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
|
||||
def SPV_OC_OpVectorInsertDynamic : I32EnumAttrCase<"OpVectorInsertDynamic", 78>;
|
||||
def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>;
|
||||
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
|
||||
def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
|
||||
|
@ -3310,9 +3311,9 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
|
||||
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
|
||||
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
|
||||
SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain,
|
||||
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
|
||||
SPV_OC_OpVectorExtractDynamic, SPV_OC_OpCompositeConstruct,
|
||||
SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
|
||||
SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,
|
||||
SPV_OC_OpVectorInsertDynamic, SPV_OC_OpCompositeConstruct,
|
||||
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
|
||||
SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
|
||||
SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
|
||||
|
|
|
@ -171,7 +171,8 @@ def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert", [NoSideEffect]> {
|
|||
// -----
|
||||
|
||||
def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
|
||||
[NoSideEffect, TypesMatchWith<"type of 'value' matches element type of 'vector'",
|
||||
[NoSideEffect,
|
||||
TypesMatchWith<"type of 'result' matches element type of 'vector'",
|
||||
"vector", "result",
|
||||
"$_self.cast<mlir::VectorType>().getElementType()">]> {
|
||||
let summary = [{
|
||||
|
@ -225,4 +226,67 @@ def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic",
|
||||
[NoSideEffect,
|
||||
TypesMatchWith<"type of 'component' matches element type of 'vector'",
|
||||
"vector", "component",
|
||||
"$_self.cast<mlir::VectorType>().getElementType()">,
|
||||
AllTypesMatch<["vector", "result"]>]> {
|
||||
let summary = [{
|
||||
Make a copy of a vector, with a single, variably selected, component
|
||||
modified.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type must be an OpTypeVector.
|
||||
|
||||
Vector must have the same type as Result Type and is the vector that the
|
||||
non-written components are copied from.
|
||||
|
||||
Component is the value supplied for the component selected by Index. It
|
||||
must have the same type as the type of components in Result Type.
|
||||
|
||||
Index must be a scalar integer. It is interpreted as a 0-based index of
|
||||
which component to modify.
|
||||
|
||||
Behavior is undefined if Index's value is less than zero or greater than
|
||||
or equal to the number of components in Vector.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
scalar-type ::= integer-type | float-type | boolean-type
|
||||
vector-insert-dynamic-op ::= `spv.VectorInsertDynamic ` ssa-use `,`
|
||||
ssa-use `[` ssa-use `]`
|
||||
`:` `vector<` integer-literal `x` scalar-type `>` `,`
|
||||
integer-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
%scalar = ... : f32
|
||||
%2 = spv.VectorInsertDynamic %scalar %0[%1] : f32, vector<8xf32>, i32
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_Vector:$vector,
|
||||
SPV_Scalar:$component,
|
||||
SPV_Integer:$index
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_Vector:$result
|
||||
);
|
||||
|
||||
let verifier = [{ return success(); }];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$component `,` $vector `[` $index `]` attr-dict `:` type($vector) `,` type($index)
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#endif // SPIRV_COMPOSITE_OPS
|
||||
|
|
|
@ -97,14 +97,32 @@ struct VectorExtractElementOpConvert final
|
|||
}
|
||||
};
|
||||
|
||||
struct VectorInsertElementOpConvert final
|
||||
: public SPIRVOpLowering<vector::InsertElementOp> {
|
||||
using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering;
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::InsertElementOp insertElementOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
|
||||
return failure();
|
||||
vector::InsertElementOp::Adaptor adaptor(operands);
|
||||
Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>(
|
||||
insertElementOp.getLoc(), insertElementOp.getType(),
|
||||
insertElementOp.dest(), adaptor.source(), insertElementOp.position());
|
||||
rewriter.replaceOp(insertElementOp, newInsertElement);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
|
||||
SPIRVTypeConverter &typeConverter,
|
||||
OwningRewritePatternList &patterns) {
|
||||
patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
|
||||
VectorInsertOpConvert, VectorExtractElementOpConvert>(
|
||||
context, typeConverter);
|
||||
VectorInsertOpConvert, VectorExtractElementOpConvert,
|
||||
VectorInsertElementOpConvert>(context, typeConverter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -39,3 +39,21 @@ func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
|
|||
%0 = vector.extractelement %arg0[%id : i32] : vector<5xf32>
|
||||
spv.ReturnValue %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: insert_element
|
||||
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
|
||||
// CHECK: spv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
|
||||
func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) {
|
||||
%0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32>
|
||||
spv.ReturnValue %0: vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) {
|
||||
// expected-error @+1 {{failed to legalize operation 'vector.insertelement'}}
|
||||
%0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32>
|
||||
spv.Return
|
||||
}
|
||||
|
|
|
@ -16,4 +16,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
|||
%0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
|
||||
spv.ReturnValue %0: f32
|
||||
}
|
||||
spv.func @vector_dynamic_insert(%val: f32, %vec: vector<4xf32>, %id : i32) -> vector<4xf32> "None" {
|
||||
// CHECK: spv.VectorInsertDynamic %{{.*}}, %{{.*}}[%{{.*}}] : vector<4xf32>, i32
|
||||
%0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
|
||||
spv.ReturnValue %0: vector<4xf32>
|
||||
}
|
||||
}
|
||||
|
|
|
@ -273,3 +273,13 @@ func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 {
|
|||
%0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.VectorInsertDynamic
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @vector_dynamic_insert(%val: f32, %vec: vector<4xf32>, %id : i32) -> vector<4xf32> {
|
||||
// CHECK: spv.VectorInsertDynamic %{{.*}}, %{{.*}}[%{{.*}}] : vector<4xf32>, i32
|
||||
%0 = spv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue