[mlir][spirv] Fix spv.CompositeConstruct assembly and validation

This commit fixes spv.CompositeConstruct to assembly to list
operand types to enable vector construction out of smaller vectors.
Validation is also fixed to properly check the cases for vector
construction.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D130669
This commit is contained in:
Lei Zhang 2022-07-27 19:16:56 -04:00
parent bfdca1535c
commit 7668e58210
8 changed files with 116 additions and 88 deletions

View File

@ -64,6 +64,10 @@ def SPV_CompositeConstructOp : SPV_Op<"CompositeConstruct", [NoSideEffect]> {
let results = (outs
SPV_Composite:$result
);
let assemblyFormat = [{
$constituents attr-dict `:` `(` type(operands) `)` `->` type($result)
}];
}
// -----

View File

@ -31,6 +31,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/bit.h"
#include <numeric>
using namespace mlir;
@ -1618,66 +1619,64 @@ LogicalResult spirv::BranchConditionalOp::verify() {
// spv.CompositeConstruct
//===----------------------------------------------------------------------===//
ParseResult spirv::CompositeConstructOp::parse(OpAsmParser &parser,
OperationState &state) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
Type type;
auto loc = parser.getCurrentLocation();
if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
return failure();
}
auto cType = type.dyn_cast<spirv::CompositeType>();
if (!cType) {
return parser.emitError(
loc, "result type must be a composite type, but provided ")
<< type;
}
if (cType.hasCompileTimeKnownNumElements() &&
operands.size() != cType.getNumElements()) {
return parser.emitError(loc, "has incorrect number of operands: expected ")
<< cType.getNumElements() << ", but provided " << operands.size();
}
// TODO: Add support for constructing a vector type from the vector operands.
// According to the spec: "for constructing a vector, the operands may
// also be vectors with the same component type as the Result Type component
// type".
SmallVector<Type, 4> elementTypes;
elementTypes.reserve(operands.size());
for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
elementTypes.push_back(cType.getElementType(index));
}
state.addTypes(type);
return parser.resolveOperands(operands, elementTypes, loc, state.operands);
}
void spirv::CompositeConstructOp::print(OpAsmPrinter &printer) {
printer << " " << constituents() << " : " << getResult().getType();
}
LogicalResult spirv::CompositeConstructOp::verify() {
auto cType = getType().cast<spirv::CompositeType>();
operand_range constituents = this->constituents();
if (cType.isa<spirv::CooperativeMatrixNVType>()) {
if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
if (constituents.size() != 1)
return emitError("has incorrect number of operands: expected ")
return emitOpError("has incorrect number of operands: expected ")
<< "1, but provided " << constituents.size();
} else if (constituents.size() != cType.getNumElements()) {
return emitError("has incorrect number of operands: expected ")
<< cType.getNumElements() << ", but provided "
<< constituents.size();
if (coopType.getElementType() != constituents.front().getType())
return emitOpError("operand type mismatch: expected operand type ")
<< coopType.getElementType() << ", but provided "
<< constituents.front().getType();
return success();
}
if (constituents.size() == cType.getNumElements()) {
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
if (constituents[index].getType() != cType.getElementType(index)) {
return emitError("operand type mismatch: expected operand type ")
return emitOpError("operand type mismatch: expected operand type ")
<< cType.getElementType(index) << ", but provided "
<< constituents[index].getType();
}
}
return success();
}
// If not constructing a cooperative matrix type, then we must be constructing
// a vector type.
auto resultType = cType.dyn_cast<VectorType>();
if (!resultType)
return emitOpError(
"expected to return a vector or cooperative matrix when the number of "
"constituents is less than what the result needs");
SmallVector<unsigned> sizes;
for (Value component : constituents) {
if (!component.getType().isa<VectorType>() &&
!component.getType().isIntOrFloat())
return emitOpError("operand type mismatch: expected operand to have "
"a scalar or vector type, but provided ")
<< component.getType();
Type elementType = component.getType();
if (auto vectorType = component.getType().dyn_cast<VectorType>()) {
sizes.push_back(vectorType.getNumElements());
elementType = vectorType.getElementType();
} else {
sizes.push_back(1);
}
if (elementType != resultType.getElementType())
return emitOpError("operand element type mismatch: expected to be ")
<< resultType.getElementType() << ", but provided " << elementType;
}
unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
if (totalCount != cType.getNumElements())
return emitOpError("has incorrect number of operands: expected ")
<< cType.getNumElements() << ", but provided " << totalCount;
return success();
}

View File

@ -32,8 +32,8 @@ func.func @copy_sign_vector(%value: vector<3xf16>, %sign: vector<3xf16>) -> vect
// CHECK-SAME: (%[[VALUE:.+]]: vector<3xf16>, %[[SIGN:.+]]: vector<3xf16>)
// CHECK: %[[SMASK:.+]] = spv.Constant -32768 : i16
// CHECK: %[[VMASK:.+]] = spv.Constant 32767 : i16
// CHECK: %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]] : vector<3xi16>
// CHECK: %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]] : vector<3xi16>
// CHECK: %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]]
// CHECK: %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]]
// CHECK: %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : vector<3xf16> to vector<3xi16>
// CHECK: %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : vector<3xf16> to vector<3xi16>
// CHECK: %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VVMASK]] : vector<3xi16>

View File

@ -18,8 +18,8 @@ func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16
// CHECK-LABEL: @broadcast
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
// CHECK: spv.CompositeConstruct %[[A]], %[[A]] : vector<2xf32>
// CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
// CHECK: spv.CompositeConstruct %[[A]], %[[A]]
func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
%0 = vector.broadcast %arg0 : f32 to vector<4xf32>
%1 = vector.broadcast %arg0 : f32 to vector<2xf32>
@ -182,7 +182,7 @@ func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf
// CHECK-LABEL: func @splat
// CHECK-SAME: (%[[A:.+]]: f32)
// CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
// CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
// CHECK: return %[[VAL]]
func.func @splat(%f : f32) -> vector<4xf32> {
%splat = vector.splat %f : vector<4xf32>
@ -206,7 +206,7 @@ func.func @splat_size1_vector(%f : f32) -> vector<1xf32> {
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>
// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
// CHECK: spv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : vector<4xf32>
// CHECK: spv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (f32, f32, f32, f32) -> vector<4xf32>
func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
%shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xf32>, vector<1xf32>
return %shuffle : vector<4xf32>

View File

@ -5,48 +5,41 @@
//===----------------------------------------------------------------------===//
func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
%0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
%0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
return %0: vector<3xf32>
}
// -----
func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>, %arg2 : !spv.struct<(f32)>) -> !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)> {
// CHECK: spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<(vector<3xf32>, !spv.array<4 x f32>, !spv.struct<(f32)>)>
%0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)>
// CHECK: spv.CompositeConstruct
%0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>) -> !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)>
return %0: !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)>
}
// -----
// CHECK-LABEL: func @composite_construct_mixed_scalar_vector
func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
// CHECK: spv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
%0 = spv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xf32>, f32) -> vector<4xf32>
return %0: vector<4xf32>
}
// -----
func.func @composite_construct_coopmatrix(%arg0 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
// CHECK: spv.CompositeConstruct {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
%0 = spv.CompositeConstruct %arg0 : !spv.coopmatrix<8x16xf32, Subgroup>
// CHECK: spv.CompositeConstruct {{%.*}} : (f32) -> !spv.coopmatrix<8x16xf32, Subgroup>
%0 = spv.CompositeConstruct %arg0 : (f32) -> !spv.coopmatrix<8x16xf32, Subgroup>
return %0: !spv.coopmatrix<8x16xf32, Subgroup>
}
// -----
func.func @composite_construct_empty_struct() -> !spv.struct<()> {
// CHECK: spv.CompositeConstruct : !spv.struct<()>
%0 = spv.CompositeConstruct : !spv.struct<()>
return %0: !spv.struct<()>
}
// -----
func.func @composite_construct_invalid_num_of_elements(%arg0: f32) -> f32 {
// expected-error @+1 {{result type must be a composite type, but provided 'f32'}}
%0 = spv.CompositeConstruct %arg0 : f32
return %0: f32
}
// -----
func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
// expected-error @+1 {{has incorrect number of operands: expected 3, but provided 2}}
%0 = spv.CompositeConstruct %arg0, %arg2 : vector<3xf32>
%0 = spv.CompositeConstruct %arg0, %arg2 : (f32, f32) -> vector<3xf32>
return %0: vector<3xf32>
}
@ -54,20 +47,52 @@ func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2
func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xi32> {
// expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32'}}
%0 = "spv.CompositeConstruct" (%arg0, %arg1, %arg2) : (f32, f32, f32) -> vector<3xi32>
%0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xi32>
return %0: vector<3xi32>
}
// -----
func.func @composite_construct_coopmatrix(%arg0 : f32, %arg1 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
func.func @composite_construct_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
// expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
%0 = spv.CompositeConstruct %arg0, %arg1 : !spv.coopmatrix<8x16xf32, Subgroup>
%0 = spv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spv.coopmatrix<8x16xf32, Subgroup>
return %0: !spv.coopmatrix<8x16xf32, Subgroup>
}
// -----
func.func @composite_construct_coopmatrix_incorrect_element_type(%arg0 : i32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
// expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
%0 = spv.CompositeConstruct %arg0 : (i32) -> !spv.coopmatrix<8x16xf32, Subgroup>
return %0: !spv.coopmatrix<8x16xf32, Subgroup>
}
// -----
func.func @composite_construct_array(%arg0: f32) -> !spv.array<4xf32> {
// expected-error @+1 {{expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
%0 = spv.CompositeConstruct %arg0 : (f32) -> !spv.array<4xf32>
return %0: !spv.array<4xf32>
}
// -----
func.func @composite_construct_vector_wrong_element_type(%arg0: f32, %arg1: f32, %arg2 : vector<2xi32>) -> vector<4xf32> {
// expected-error @+1 {{operand element type mismatch: expected to be 'f32', but provided 'i32'}}
%0 = spv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xi32>, f32) -> vector<4xf32>
return %0: vector<4xf32>
}
// -----
func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
// expected-error @+1 {{op has incorrect number of operands: expected 4, but provided 3}}
%0 = spv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32>
return %0: vector<4xf32>
}
// -----
//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//

View File

@ -3,26 +3,26 @@
spv.module Logical GLSL450 {
spv.func @rewrite(%value0 : f32, %value1 : f32, %value2 : f32, %value3 : i32, %value4: !spv.array<3xf32>) -> vector<3xf32> "None" {
%0 = spv.Undef : vector<3xf32>
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
%1 = spv.CompositeInsert %value0, %0[0 : i32] : f32 into vector<3xf32>
%2 = spv.CompositeInsert %value1, %1[1 : i32] : f32 into vector<3xf32>
%3 = spv.CompositeInsert %value2, %2[2 : i32] : f32 into vector<3xf32>
%4 = spv.Undef : !spv.array<4xf32>
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spv.array<4 x f32>
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32, f32) -> !spv.array<4 x f32>
%5 = spv.CompositeInsert %value0, %4[0 : i32] : f32 into !spv.array<4xf32>
%6 = spv.CompositeInsert %value1, %5[1 : i32] : f32 into !spv.array<4xf32>
%7 = spv.CompositeInsert %value2, %6[2 : i32] : f32 into !spv.array<4xf32>
%8 = spv.CompositeInsert %value0, %7[3 : i32] : f32 into !spv.array<4xf32>
%9 = spv.Undef : !spv.struct<(f32, i32, f32)>
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : !spv.struct<(f32, i32, f32)>
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, i32, f32) -> !spv.struct<(f32, i32, f32)>
%10 = spv.CompositeInsert %value0, %9[0 : i32] : f32 into !spv.struct<(f32, i32, f32)>
%11 = spv.CompositeInsert %value3, %10[1 : i32] : i32 into !spv.struct<(f32, i32, f32)>
%12 = spv.CompositeInsert %value1, %11[2 : i32] : f32 into !spv.struct<(f32, i32, f32)>
%13 = spv.Undef : !spv.struct<(f32, !spv.array<3xf32>)>
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : !spv.struct<(f32, !spv.array<3 x f32>)>
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : (f32, !spv.array<3 x f32>) -> !spv.struct<(f32, !spv.array<3 x f32>)>
%14 = spv.CompositeInsert %value0, %13[0 : i32] : f32 into !spv.struct<(f32, !spv.array<3xf32>)>
%15 = spv.CompositeInsert %value4, %14[1 : i32] : !spv.array<3xf32> into !spv.struct<(f32, !spv.array<3xf32>)>

View File

@ -7,8 +7,8 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.ReturnValue %0: !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)>
}
spv.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> "None" {
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
%0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
%0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
spv.ReturnValue %0: vector<3xf32>
}
spv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" {

View File

@ -33,7 +33,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK: loc({{".*debug.mlir"}}:34:10)
%0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)>
// CHECK: loc({{".*debug.mlir"}}:36:10)
%1 = spv.CompositeConstruct %arg2, %arg3 : vector<2xf32>
%1 = spv.CompositeConstruct %arg2, %arg3 : (f32, f32) -> vector<2xf32>
spv.Return
}