diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 87456f000edc..3742bab414ec 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -723,12 +723,6 @@ static LogicalResult verifyShiftOp(Operation *op) { //===----------------------------------------------------------------------===// static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { - if (indices.empty()) { - emitError(baseLoc, "'spv.AccessChain' op expected at least " - "one index "); - return nullptr; - } - auto ptrType = type.dyn_cast(); if (!ptrType) { emitError(baseLoc, "'spv.AccessChain' op expected a pointer " @@ -791,19 +785,37 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser, OpAsmParser::OperandType ptrInfo; SmallVector indicesInfo; Type type; - // TODO(denis0x0D): regarding to the spec an index must be any integer type, - // figure out how to use resolveOperand with a range of types and do not - // fail on first attempt. - Type indicesType = parser.getBuilder().getIntegerType(32); + auto loc = parser.getCurrentLocation(); + SmallVector indicesTypes; if (parser.parseOperand(ptrInfo) || parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || parser.parseColonType(type) || - parser.resolveOperand(ptrInfo, type, state.operands) || - parser.resolveOperands(indicesInfo, indicesType, state.operands)) { + parser.resolveOperand(ptrInfo, type, state.operands)) { return failure(); } + // Check that the provided indices list is not empty before parsing their + // type list. + if (indicesInfo.empty()) { + return emitError(state.location, "'spv.AccessChain' op expected at " + "least one index "); + } + + if (parser.parseComma() || parser.parseTypeList(indicesTypes)) + return failure(); + + // Check that the indices types list is not empty and that it has a one-to-one + // mapping to the provided indices. + if (indicesTypes.size() != indicesInfo.size()) { + return emitError(state.location, "'spv.AccessChain' op indices " + "types' count must be equal to indices " + "info count"); + } + + if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) + return failure(); + auto resultType = getElementPtrType( type, llvm::makeArrayRef(state.operands).drop_front(), state.location); if (!resultType) { @@ -816,7 +828,8 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser, static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr() - << '[' << op.indices() << "] : " << op.base_ptr().getType(); + << '[' << op.indices() << "] : " << op.base_ptr().getType() << ", " + << op.indices().getTypes(); } static LogicalResult verify(spirv::AccessChainOp accessChainOp) { diff --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir index 726b276010ef..43da9c5be429 100644 --- a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir +++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir @@ -11,7 +11,7 @@ module attributes {gpu.container_module} { %0 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> %2 = spv.constant 0 : i32 %3 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> - %4 = spv.AccessChain %0[%2, %2] : !spv.ptr [0]>, StorageBuffer> + %4 = spv.AccessChain %0[%2, %2] : !spv.ptr [0]>, StorageBuffer>, i32, i32 %5 = spv.Load "StorageBuffer" %4 : f32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/array.mlir b/mlir/test/Dialect/SPIRV/Serialization/array.mlir index 0d14cbf9d3b6..68986741b32d 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/array.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/array.mlir @@ -2,8 +2,8 @@ spv.module Logical GLSL450 requires #spv.vce { spv.func @array_stride(%arg0 : !spv.ptr, stride=128>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" { - // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr, stride=128>, StorageBuffer> - %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr, stride=128>, StorageBuffer> + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr, stride=128>, StorageBuffer>, i32, i32 + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr, stride=128>, StorageBuffer>, i32, i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir index 0d58fea18a11..ad913dfb1624 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir @@ -95,8 +95,8 @@ spv.module Logical GLSL450 requires #spv.vce, Function>) -> !spv.ptr "None" { %0 = spv.constant 0: i32 - // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function> - %1 = spv.AccessChain %a[%0] : !spv.ptr, Function> + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function>, i32 + %1 = spv.AccessChain %a[%0] : !spv.ptr, Function>, i32 spv.ReturnValue %1 : !spv.ptr } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir index aa9653da4f2e..d83030d25298 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir @@ -58,7 +58,7 @@ spv.module Logical GLSL450 requires #spv.vce { spv.func @memory_accesses(%arg0 : !spv.ptr>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" { // CHECK: loc({{".*debug.mlir"}}:61:10) - %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, StorageBuffer> + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, StorageBuffer>, i32, i32 // CHECK: loc({{".*debug.mlir"}}:63:10) %3 = spv.Load "StorageBuffer" %2 : f32 // CHECK: loc({{.*debug.mlir"}}:65:5) diff --git a/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir b/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir index faa371ea9016..0365fef03fbb 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir @@ -30,7 +30,7 @@ spv.module Logical GLSL450 requires #spv.vce { %0 = spv._address_of @globalInvocationID : !spv.ptr, Input> %1 = spv.constant 0: i32 // CHECK: spv.AccessChain %[[ADDR]] - %2 = spv.AccessChain %0[%1] : !spv.ptr, Input> + %2 = spv.AccessChain %0[%1] : !spv.ptr, Input>, i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir index 1b041a4aa604..d6e2090f02bb 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir @@ -65,9 +65,9 @@ spv.module Logical GLSL450 requires #spv.vce { spv.func @loop_kernel() "None" { %0 = spv._address_of @GV1 : !spv.ptr [0]>, StorageBuffer> %1 = spv.constant 0 : i32 - %2 = spv.AccessChain %0[%1] : !spv.ptr [0]>, StorageBuffer> + %2 = spv.AccessChain %0[%1] : !spv.ptr [0]>, StorageBuffer>, i32 %3 = spv._address_of @GV2 : !spv.ptr [0]>, StorageBuffer> - %5 = spv.AccessChain %3[%1] : !spv.ptr [0]>, StorageBuffer> + %5 = spv.AccessChain %3[%1] : !spv.ptr [0]>, StorageBuffer>, i32 %6 = spv.constant 4 : i32 %7 = spv.constant 42 : i32 %8 = spv.constant 2 : i32 @@ -84,9 +84,9 @@ spv.module Logical GLSL450 requires #spv.vce { spv.BranchConditional %10, ^body, ^merge // CHECK-NEXT: ^bb2: // pred: ^bb1 ^body: - %11 = spv.AccessChain %2[%9] : !spv.ptr, StorageBuffer> + %11 = spv.AccessChain %2[%9] : !spv.ptr, StorageBuffer>, i32 %12 = spv.Load "StorageBuffer" %11 : f32 - %13 = spv.AccessChain %5[%9] : !spv.ptr, StorageBuffer> + %13 = spv.AccessChain %5[%9] : !spv.ptr, StorageBuffer>, i32 spv.Store "StorageBuffer" %13, %12 : f32 // CHECK: %[[ADD:.*]] = spv.IAdd %14 = spv.IAdd %9, %8 : i32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir index 8dc90cb504e8..e10bfc88afb0 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir @@ -4,7 +4,7 @@ spv.module Logical GLSL450 requires #spv.vce { // CHECK-LABEL: @matrix_access_chain spv.func @matrix_access_chain(%arg0 : !spv.ptr>, Function>, %arg1 : i32) -> !spv.ptr, Function> "None" { // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr>, Function> - %0 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, Function> + %0 = spv.AccessChain %arg0[%arg1] : !spv.ptr>,Function>, i32 spv.ReturnValue %0 : !spv.ptr, Function> } @@ -20,6 +20,7 @@ spv.module Logical GLSL450 requires #spv.vce { // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>> %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>> spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>> + } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir index fbe45fa87d20..26584a479dec 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -18,8 +18,8 @@ spv.module Logical GLSL450 requires #spv.vce { spv.func @access_chain(%arg0 : !spv.ptr>, Function>, %arg1 : i32, %arg2 : i32) "None" { // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr>, Function> // CHECK-NEXT: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr>, Function> - %1 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, Function> - %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, Function> + %1 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, Function>, i32 + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, Function>, i32, i32 spv.Return } } @@ -31,13 +31,13 @@ spv.module Logical GLSL450 requires #spv.vce { // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32 %0 = spv.constant 0 : i32 - %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer>, i32, i32 %2 = spv.Load "StorageBuffer" %1 : f32 // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32 %3 = spv.constant 0 : i32 - %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer>, i32, i32 spv.Store "StorageBuffer" %4, %2 : f32 spv.Return } @@ -46,13 +46,13 @@ spv.module Logical GLSL450 requires #spv.vce { // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32 %0 = spv.constant 0 : i32 - %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer>, i32, i32 %2 = spv.Load "StorageBuffer" %1 : i32 // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32 %3 = spv.constant 0 : i32 - %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer>, i32, i32 spv.Store "StorageBuffer" %4, %2 : i32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir index 6998930911db..d19812f48257 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir @@ -16,7 +16,7 @@ spv.module Logical GLSL450 requires #spv.vce { // CHECK: {{%.*}} = spv.undef : !spv.ptr, StorageBuffer> %7 = spv.undef : !spv.ptr, StorageBuffer> %8 = spv.constant 0 : i32 - %9 = spv.AccessChain %7[%8] : !spv.ptr, StorageBuffer> + %9 = spv.AccessChain %7[%8] : !spv.ptr, StorageBuffer>, i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir index 075ef3398d83..4e4bf06e6f73 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -102,14 +102,14 @@ spv.module Logical GLSL450 { %37 = spv.IAdd %arg4, %11 : i32 // CHECK: spv.AccessChain [[ARG0]] %c0 = spv.constant 0 : i32 - %38 = spv.AccessChain %arg0[%c0, %36, %37] : !spv.ptr>>, StorageBuffer> + %38 = spv.AccessChain %arg0[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 %39 = spv.Load "StorageBuffer" %38 : f32 // CHECK: spv.AccessChain [[ARG1]] - %40 = spv.AccessChain %arg1[%c0, %36, %37] : !spv.ptr>>, StorageBuffer> + %40 = spv.AccessChain %arg1[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 %41 = spv.Load "StorageBuffer" %40 : f32 %42 = spv.FAdd %39, %41 : f32 // CHECK: spv.AccessChain [[ARG2]] - %43 = spv.AccessChain %arg2[%c0, %36, %37] : !spv.ptr>>, StorageBuffer> + %43 = spv.AccessChain %arg2[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 spv.Store "StorageBuffer" %43, %42 : f32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir index 24b22c5dbf6f..8c9408ab089f 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir @@ -37,7 +37,7 @@ spv.module Logical GLSL450 { spv.func @callee() "None" { %0 = spv._address_of @data : !spv.ptr [0]>, StorageBuffer> %1 = spv.constant 0: i32 - %2 = spv.AccessChain %0[%1, %1] : !spv.ptr [0]>, StorageBuffer> + %2 = spv.AccessChain %0[%1, %1] : !spv.ptr [0]>, StorageBuffer>, i32, i32 spv.Branch ^next ^next: @@ -196,7 +196,7 @@ spv.module Logical GLSL450 { // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOADPTR]] %2 = spv._address_of @arg_0 : !spv.ptr, StorageBuffer> %3 = spv._address_of @arg_1 : !spv.ptr, StorageBuffer> - %4 = spv.AccessChain %2[%1] : !spv.ptr, StorageBuffer> + %4 = spv.AccessChain %2[%1] : !spv.ptr, StorageBuffer>, i32 %5 = spv.Load "StorageBuffer" %4 : i32 %6 = spv.SGreaterThan %5, %1 : i32 // CHECK: spv.selection @@ -204,7 +204,7 @@ spv.module Logical GLSL450 { spv.BranchConditional %6, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: [[STOREPTR:%.*]] = spv.AccessChain [[ADDRESS_ARG1]] - %7 = spv.AccessChain %3[%1] : !spv.ptr, StorageBuffer> + %7 = spv.AccessChain %3[%1] : !spv.ptr, StorageBuffer>, i32 // CHECK-NOT: spv.FunctionCall // CHECK: spv.AtomicIAdd "Device" "AcquireRelease" [[STOREPTR]], [[VAL]] // CHECK: spv.Branch diff --git a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir index 975012d3a26a..f54d1910be22 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir @@ -24,7 +24,7 @@ spv.module Logical GLSL450 { // CHECK: {{%.*}} = spv._address_of @var0 : !spv.ptr [4], f32 [12]>, Uniform> %0 = spv._address_of @var0 : !spv.ptr, f32>, Uniform> // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr [4], f32 [12]>, Uniform> - %1 = spv.AccessChain %0[%c0] : !spv.ptr, f32>, Uniform> + %1 = spv.AccessChain %0[%c0] : !spv.ptr, f32>, Uniform>, i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir index 20ed6e96be8d..2b719fd7219d 100644 --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -11,8 +11,8 @@ func @combine_full_access_chain() -> f32 { // CHECK-NEXT: spv.Load "Function" %[[PTR]] %c0 = spv.constant 0: i32 %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function> - %2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr>, Function>, i32, i32 %3 = spv.Load "Function" %2 : f32 spv.ReturnValue %3 : f32 } @@ -28,9 +28,9 @@ func @combine_access_chain_multi_use() -> !spv.array<4xf32> { // CHECK-NEXT: spv.Load "Function" %[[PTR_1]] %c0 = spv.constant 0: i32 %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function> - %2 = spv.AccessChain %1[%c0] : !spv.ptr>, Function> - %3 = spv.AccessChain %2[%c0] : !spv.ptr, Function> + %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %2 = spv.AccessChain %1[%c0] : !spv.ptr>, Function>, i32 + %3 = spv.AccessChain %2[%c0] : !spv.ptr, Function>, i32 %4 = spv.Load "Function" %2 : !spv.array<4xf32> %5 = spv.Load "Function" %3 : f32 spv.ReturnValue %4: !spv.array<4xf32> @@ -49,8 +49,8 @@ func @dont_combine_access_chain_without_common_base() -> !spv.array<4xi32> { %c1 = spv.constant 1: i32 %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> %1 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %2 = spv.AccessChain %0[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function> - %3 = spv.AccessChain %1[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function> + %2 = spv.AccessChain %0[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %3 = spv.AccessChain %1[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 %4 = spv.Load "Function" %2 : !spv.array<4xi32> %5 = spv.Load "Function" %3 : !spv.array<4xi32> spv.ReturnValue %4 : !spv.array<4xi32> diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir index a2dafaddfa21..523bd6bfb030 100644 --- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir @@ -97,8 +97,8 @@ spv.func @cooperative_matrix_fdiv(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : // CHECK-LABEL: @cooperative_matrix_access_chain spv.func @cooperative_matrix_access_chain(%a : !spv.ptr, Function>) -> !spv.ptr "None" { %0 = spv.constant 0: i32 - // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function> - %1 = spv.AccessChain %a[%0] : !spv.ptr, Function> + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function>, i32 + %1 = spv.AccessChain %a[%0] : !spv.ptr, Function>, i32 spv.ReturnValue %1 : !spv.ptr } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index 14e1fa10735e..7dea7942d426 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -8,21 +8,21 @@ func @access_chain_struct() -> () { %0 = spv.constant 1: i32 %1 = spv.Variable : !spv.ptr>, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 return } func @access_chain_1D_array(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr, Function> - %1 = spv.AccessChain %0[%arg0] : !spv.ptr, Function> + %1 = spv.AccessChain %0[%arg0] : !spv.ptr, Function>, i32 return } func @access_chain_2D_array_1(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> - %1 = spv.AccessChain %0[%arg0, %arg0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%arg0, %arg0] : !spv.ptr>, Function>, i32, i32 %2 = spv.Load "Function" %1 ["Volatile"] : f32 return } @@ -30,7 +30,7 @@ func @access_chain_2D_array_1(%arg0 : i32) -> () { func @access_chain_2D_array_2(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr>, Function> - %1 = spv.AccessChain %0[%arg0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%arg0] : !spv.ptr>, Function>, i32 %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32> return } @@ -38,7 +38,7 @@ func @access_chain_2D_array_2(%arg0 : i32) -> () { func @access_chain_rtarray(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr, Function> - %1 = spv.AccessChain %0[%arg0] : !spv.ptr, Function> + %1 = spv.AccessChain %0[%arg0] : !spv.ptr, Function>, i32 %2 = spv.Load "Function" %1 ["Volatile"] : f32 return } @@ -49,7 +49,7 @@ func @access_chain_non_composite() -> () { %0 = spv.constant 1: i32 %1 = spv.Variable : !spv.ptr // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}} - %2 = spv.AccessChain %1[%0] : !spv.ptr + %2 = spv.AccessChain %1[%0] : !spv.ptr, i32 return } @@ -58,7 +58,34 @@ func @access_chain_non_composite() -> () { func @access_chain_no_indices(%index0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{expected at least one index}} - %1 = spv.AccessChain %0[] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[] : !spv.ptr>, Function>, i32 + return +} + +// ----- + +func @access_chain_missing_comma(%index0 : i32) -> () { + %0 = spv.Variable : !spv.ptr>, Function> + // expected-error @+1 {{expected ','}} + %1 = spv.AccessChain %0[%index0] : !spv.ptr>, Function> i32 + return +} + +// ----- + +func @access_chain_invalid_indices_types_count(%index0 : i32) -> () { + %0 = spv.Variable : !spv.ptr>, Function> + // expected-error @+1 {{'spv.AccessChain' op indices types' count must be equal to indices info count}} + %1 = spv.AccessChain %0[%index0] : !spv.ptr>, Function>, i32, i32 + return +} + +// ----- + +func @access_chain_missing_indices_type(%index0 : i32) -> () { + %0 = spv.Variable : !spv.ptr>, Function> + // expected-error @+1 {{'spv.AccessChain' op indices types' count must be equal to indices info count}} + %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function>, i32 return } @@ -68,7 +95,7 @@ func @access_chain_invalid_type(%index0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>> // expected-error @+1 {{expected a pointer to composite type, but provided '!spv.array<4 x !spv.array<4 x f32>>'}} - %2 = spv.AccessChain %1[%index0] : !spv.array<4x!spv.array<4xf32>> + %2 = spv.AccessChain %1[%index0] : !spv.array<4x!spv.array<4xf32>>, i32 return } @@ -77,7 +104,7 @@ func @access_chain_invalid_type(%index0 : i32) -> () { func @access_chain_invalid_index_1(%index0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{expected SSA operand}} - %1 = spv.AccessChain %0[%index, 4] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%index, 4] : !spv.ptr>, Function>, i32, i32 return } @@ -86,7 +113,7 @@ func @access_chain_invalid_index_1(%index0 : i32) -> () { func @access_chain_invalid_index_2(%index0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct}} - %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function>, i32, i32 return } @@ -96,7 +123,7 @@ func @access_chain_invalid_constant_type_1() -> () { %0 = std.constant 1: i32 %1 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct, but provided std.constant}} - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 return } @@ -106,7 +133,7 @@ func @access_chain_out_of_bounds() -> () { %index0 = "spv.constant"() { value = 12: i32} : () -> i32 %0 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{'spv.AccessChain' op index 12 out of bounds for '!spv.struct>'}} - %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function>, i32, i32 return } @@ -115,7 +142,7 @@ func @access_chain_out_of_bounds() -> () { func @access_chain_invalid_accessing_type(%index0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}} - %1 = spv.AccessChain %0[%index, %index0, %index0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%index, %index0, %index0] : !spv.ptr>, Function>, i32, i32, i32 return // ----- diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir index e24ebcd6e656..93df070f0a2a 100644 --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -11,7 +11,7 @@ spv.module Logical GLSL450 { // CHECK: [[VAR1:%.*]] = spv._address_of @var1 : !spv.ptr>, Input> // CHECK-NEXT: spv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spv.ptr>, Input> %1 = spv._address_of @var1 : !spv.ptr>, Input> - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Input> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Input>, i32, i32 spv.Return } }