diff --git a/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h new file mode 100644 index 000000000000..e21aa431f297 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h @@ -0,0 +1,80 @@ +//===-- LayoutUtils.h - Decorate composite type with layout information ---===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines utilities used to get alignment and layout information for +// types in SPIR-V dialect. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_SPIRV_LAYOUTUTILS_H_ +#define MLIR_DIALECT_SPIRV_LAYOUTUTILS_H_ + +#include + +namespace mlir { +class Type; +class VectorType; +namespace spirv { +class StructType; +class ArrayType; +} // namespace spirv + +/// According to the Vulkan spec "14.5.4. Offset and Stride Assignment": +/// "There are different alignment requirements depending on the specific +/// resources and on the features enabled on the device." +/// +/// There are 3 types of alignment: scalar, base, extended. +/// See the spec for details. +/// +/// Note: Even if scalar alignment is supported, it is generally more +/// performant to use the base alignment. So here the calculation is based on +/// base alignment. +/// +/// The memory layout must obey the following rules: +/// 1. The Offset decoration of any member must be a multiple of its alignment. +/// 2. Any ArrayStride or MatrixStride decoration must be a multiple of the +/// alignment of the array or matrix as defined above. +/// +/// According to the SPIR-V spec: +/// "The ArrayStride, MatrixStride, and Offset decorations must be large +/// enough to hold the size of the objects they affect (that is, specifying +/// overlap is invalid)." +class VulkanLayoutUtils { +public: + using Size = uint64_t; + + /// Returns a new type with layout info. Assigns the type size in bytes to the + /// `size`. Assigns the type alignment in bytes to the `alignment`. + static Type decorateType(spirv::StructType structType, Size &size, + Size &alignment); + /// Checks whether a type is legal in terms of Vulkan layout info + /// decoration. A type is dynamically illegal if it's a composite type in the + /// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant Storage + /// Classes without layout informtation. + static bool isLegalType(Type type); + +private: + static Type decorateType(Type type, Size &size, Size &alignment); + static Type decorateType(VectorType vectorType, Size &size, Size &alignment); + static Type decorateType(spirv::ArrayType arrayType, Size &size, + Size &alignment); + /// Calculates the alignment for the given scalar type. + static Size getScalarTypeAlignment(Type scalarType); +}; + +} // namespace mlir + +#endif // MLIR_DIALECT_SPIRV_LAYOUTUTILS_H_ diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 4031385a8f12..45720fefcc11 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -131,7 +131,12 @@ void GPUToSPIRVPass::runOnModule() { builder.getI32IntegerAttr( static_cast(spirv::AddressingModel::Logical)), builder.getI32IntegerAttr( - static_cast(spirv::MemoryModel::GLSL450))); + static_cast(spirv::MemoryModel::GLSL450)), + builder.getStrArrayAttr( + spirv::stringifyCapability(spirv::Capability::Shader)), + builder.getStrArrayAttr(spirv::stringifyExtension( + spirv::Extension::SPV_KHR_storage_buffer_storage_class))); + // Hardwire the capability to be Shader. OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0)); moduleBuilder.clone(*funcOp.getOperation()); spirvModules.push_back(spvModule); diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 0b5790fd06ec..a0906b75950d 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" +#include "mlir/Dialect/SPIRV/LayoutUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/StandardOps/Ops.h" @@ -83,18 +84,25 @@ Type SPIRVBasicTypeConverter::convertType(Type t) { // Entry Function signature Conversion //===----------------------------------------------------------------------===// +Type getLayoutDecoratedType(spirv::StructType type) { + VulkanLayoutUtils::Size size = 0, alignment = 0; + return VulkanLayoutUtils::decorateType(type, size, alignment); +} + /// Generates the type of variable given the type of object. static Type getGlobalVarTypeForEntryFnArg(Type t) { auto convertedType = basicTypeConversion(t); if (auto ptrType = convertedType.dyn_cast()) { if (!ptrType.getPointeeType().isa()) { return spirv::PointerType::get( - spirv::StructType::get(ptrType.getPointeeType()), + getLayoutDecoratedType( + spirv::StructType::get(ptrType.getPointeeType())), ptrType.getStorageClass()); } } else { - return spirv::PointerType::get(spirv::StructType::get(convertedType), - spirv::StorageClass::StorageBuffer); + return spirv::PointerType::get( + getLayoutDecoratedType(spirv::StructType::get(convertedType)), + spirv::StorageClass::StorageBuffer); } return convertedType; } @@ -119,7 +127,7 @@ static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter, spirv::GlobalVariableOp var; { OpBuilder::InsertionGuard moduleInsertionGuard(rewriter); - rewriter.setInsertionPointToStart(&module.getBlock()); + rewriter.setInsertionPoint(funcOp.getOperation()); std::string varName = funcOp.getName().str() + "_arg_" + std::to_string(origArgNum); var = rewriter.create( diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt index 04bdc73f22a5..a6a36f4f11c7 100644 --- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt @@ -3,6 +3,7 @@ add_llvm_library(MLIRSPIRV SPIRVDialect.cpp SPIRVOps.cpp SPIRVTypes.cpp + LayoutUtils.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp new file mode 100644 index 000000000000..eee01f155e91 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp @@ -0,0 +1,165 @@ +//===-- LayoutUtils.cpp - Decorate composite type with layout information -===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements Utilities used to get alignment and layout information +// for types in SPIR-V dialect. +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/SPIRV/LayoutUtils.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" + +using namespace mlir; + +Type VulkanLayoutUtils::decorateType(spirv::StructType structType, + VulkanLayoutUtils::Size &size, + VulkanLayoutUtils::Size &alignment) { + if (structType.getNumElements() == 0) { + return structType; + } + + llvm::SmallVector memberTypes; + llvm::SmallVector layoutInfo; + llvm::SmallVector + memberDecorations; + + VulkanLayoutUtils::Size structMemberOffset = 0; + VulkanLayoutUtils::Size maxMemberAlignment = 1; + + for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) { + VulkanLayoutUtils::Size memberSize = 0; + VulkanLayoutUtils::Size memberAlignment = 1; + + auto memberType = VulkanLayoutUtils::decorateType( + structType.getElementType(i), memberSize, memberAlignment); + structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment); + memberTypes.push_back(memberType); + layoutInfo.push_back(structMemberOffset); + // According to the Vulkan spec: + // "A structure has a base alignment equal to the largest base alignment of + // any of its members." + structMemberOffset += memberSize; + maxMemberAlignment = std::max(maxMemberAlignment, memberAlignment); + } + + // According to the Vulkan spec: + // "The Offset decoration of a member must not place it between the end of a + // structure or an array and the next multiple of the alignment of that + // structure or array." + size = llvm::alignTo(structMemberOffset, maxMemberAlignment); + alignment = maxMemberAlignment; + structType.getMemberDecorations(memberDecorations); + return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations); +} + +Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, + VulkanLayoutUtils::Size &alignment) { + if (spirv::SPIRVDialect::isValidScalarType(type)) { + alignment = VulkanLayoutUtils::getScalarTypeAlignment(type); + // Vulkan spec does not specify any padding for a scalar type. + size = alignment; + return type; + } + + switch (type.getKind()) { + case spirv::TypeKind::Struct: + return VulkanLayoutUtils::decorateType(type.cast(), size, + alignment); + case spirv::TypeKind::Array: + return VulkanLayoutUtils::decorateType(type.cast(), size, + alignment); + case StandardTypes::Vector: + return VulkanLayoutUtils::decorateType(type.cast(), size, + alignment); + default: + llvm_unreachable("unhandled SPIR-V type"); + } +} + +Type VulkanLayoutUtils::decorateType(VectorType vectorType, + VulkanLayoutUtils::Size &size, + VulkanLayoutUtils::Size &alignment) { + const auto numElements = vectorType.getNumElements(); + auto elementType = vectorType.getElementType(); + VulkanLayoutUtils::Size elementSize = 0; + VulkanLayoutUtils::Size elementAlignment = 1; + + auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize, + elementAlignment); + // According to the Vulkan spec: + // 1. "A two-component vector has a base alignment equal to twice its scalar + // alignment." + // 2. "A three- or four-component vector has a base alignment equal to four + // times its scalar alignment." + size = elementSize * numElements; + alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4; + return VectorType::get(numElements, memberType); +} + +Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType, + VulkanLayoutUtils::Size &size, + VulkanLayoutUtils::Size &alignment) { + const auto numElements = arrayType.getNumElements(); + auto elementType = arrayType.getElementType(); + spirv::ArrayType::LayoutInfo elementSize = 0; + VulkanLayoutUtils::Size elementAlignment = 1; + + auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize, + elementAlignment); + // According to the Vulkan spec: + // "An array has a base alignment equal to the base alignment of its element + // type." + size = elementSize * numElements; + alignment = elementAlignment; + return spirv::ArrayType::get(memberType, numElements, elementSize); +} + +VulkanLayoutUtils::Size +VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) { + // According to the Vulkan spec: + // 1. "A scalar of size N has a scalar alignment of N." + // 2. "A scalar has a base alignment equal to its scalar alignment." + // 3. "A scalar, vector or matrix type has an extended alignment equal to its + // base alignment." + auto bitWidth = scalarType.getIntOrFloatBitWidth(); + if (bitWidth == 1) + return 1; + return bitWidth / 8; +} + +bool VulkanLayoutUtils::isLegalType(Type type) { + auto ptrType = type.dyn_cast(); + if (!ptrType) { + return true; + } + + auto storageClass = ptrType.getStorageClass(); + auto structType = ptrType.getPointeeType().dyn_cast(); + if (!structType) { + return true; + } + + switch (storageClass) { + case spirv::StorageClass::Uniform: + case spirv::StorageClass::StorageBuffer: + case spirv::StorageClass::PushConstant: + case spirv::StorageClass::PhysicalStorageBuffer: + return structType.hasLayout() || !structType.getNumElements(); + default: + return true; + } +} diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp index d2693f279456..a854a1d511c2 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -22,6 +22,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/SPIRV/LayoutUtils.h" #include "mlir/Dialect/SPIRV/Passes.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" @@ -29,196 +30,6 @@ using namespace mlir; -/// According to the Vulkan spec "14.5.4. Offset and Stride Assignment": -/// "There are different alignment requirements depending on the specific -/// resources and on the features enabled on the device." -/// -/// There are 3 types of alignment: scalar, base, extended. -/// See the spec for details. -/// -/// Note: Even if scalar alignment is supported, it is generally more -/// performant to use the base alignment. So here the calculation is based on -/// base alignment. -/// -/// The memory layout must obey the following rules: -/// 1. The Offset decoration of any member must be a multiple of its alignment. -/// 2. Any ArrayStride or MatrixStride decoration must be a multiple of the -/// alignment of the array or matrix as defined above. -/// -/// According to the SPIR-V spec: -/// "The ArrayStride, MatrixStride, and Offset decorations must be large -/// enough to hold the size of the objects they affect (that is, specifying -/// overlap is invalid)." -namespace { -class VulkanLayoutUtils { -public: - using Alignment = uint64_t; - - /// Returns a new type with layout info. Assigns the type size in bytes to the - /// `size`. Assigns the type alignment in bytes to the `alignment`. - static Type decorateType(spirv::StructType structType, - spirv::StructType::LayoutInfo &size, - Alignment &alignment); - /// Checks whether a type is legal in terms of Vulkan layout info - /// decoration. A type is dynamically illegal if it's a composite type in the - /// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant Storage - /// Classes without layout informtation. - static bool isLegalType(Type type); - -private: - static Type decorateType(Type type, spirv::StructType::LayoutInfo &size, - Alignment &alignment); - static Type decorateType(VectorType vectorType, - spirv::StructType::LayoutInfo &size, - Alignment &alignment); - static Type decorateType(spirv::ArrayType arrayType, - spirv::StructType::LayoutInfo &size, - Alignment &alignment); - /// Calculates the alignment for the given scalar type. - static Alignment getScalarTypeAlignment(Type scalarType); -}; - -Type VulkanLayoutUtils::decorateType(spirv::StructType structType, - spirv::StructType::LayoutInfo &size, - VulkanLayoutUtils::Alignment &alignment) { - if (structType.getNumElements() == 0) { - return structType; - } - - llvm::SmallVector memberTypes; - llvm::SmallVector layoutInfo; - llvm::SmallVector - memberDecorations; - - spirv::StructType::LayoutInfo structMemberOffset = 0; - VulkanLayoutUtils::Alignment maxMemberAlignment = 1; - - for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) { - spirv::StructType::LayoutInfo memberSize = 0; - VulkanLayoutUtils::Alignment memberAlignment = 1; - - auto memberType = VulkanLayoutUtils::decorateType( - structType.getElementType(i), memberSize, memberAlignment); - structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment); - memberTypes.push_back(memberType); - layoutInfo.push_back(structMemberOffset); - // According to the Vulkan spec: - // "A structure has a base alignment equal to the largest base alignment of - // any of its members." - structMemberOffset += memberSize; - maxMemberAlignment = std::max(maxMemberAlignment, memberAlignment); - } - - // According to the Vulkan spec: - // "The Offset decoration of a member must not place it between the end of a - // structure or an array and the next multiple of the alignment of that - // structure or array." - size = llvm::alignTo(structMemberOffset, maxMemberAlignment); - alignment = maxMemberAlignment; - structType.getMemberDecorations(memberDecorations); - return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations); -} - -Type VulkanLayoutUtils::decorateType(Type type, - spirv::StructType::LayoutInfo &size, - VulkanLayoutUtils::Alignment &alignment) { - if (spirv::SPIRVDialect::isValidScalarType(type)) { - alignment = VulkanLayoutUtils::getScalarTypeAlignment(type); - // Vulkan spec does not specify any padding for a scalar type. - size = alignment; - return type; - } - - switch (type.getKind()) { - case spirv::TypeKind::Struct: - return VulkanLayoutUtils::decorateType(type.cast(), size, - alignment); - case spirv::TypeKind::Array: - return VulkanLayoutUtils::decorateType(type.cast(), size, - alignment); - case StandardTypes::Vector: - return VulkanLayoutUtils::decorateType(type.cast(), size, - alignment); - default: - llvm_unreachable("unhandled SPIR-V type"); - } -} - -Type VulkanLayoutUtils::decorateType(VectorType vectorType, - spirv::StructType::LayoutInfo &size, - VulkanLayoutUtils::Alignment &alignment) { - const auto numElements = vectorType.getNumElements(); - auto elementType = vectorType.getElementType(); - spirv::StructType::LayoutInfo elementSize = 0; - VulkanLayoutUtils::Alignment elementAlignment = 1; - - auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize, - elementAlignment); - // According to the Vulkan spec: - // 1. "A two-component vector has a base alignment equal to twice its scalar - // alignment." - // 2. "A three- or four-component vector has a base alignment equal to four - // times its scalar alignment." - size = elementSize * numElements; - alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4; - return VectorType::get(numElements, memberType); -} - -Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType, - spirv::StructType::LayoutInfo &size, - VulkanLayoutUtils::Alignment &alignment) { - const auto numElements = arrayType.getNumElements(); - auto elementType = arrayType.getElementType(); - spirv::ArrayType::LayoutInfo elementSize = 0; - VulkanLayoutUtils::Alignment elementAlignment = 1; - - auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize, - elementAlignment); - // According to the Vulkan spec: - // "An array has a base alignment equal to the base alignment of its element - // type." - size = elementSize * numElements; - alignment = elementAlignment; - return spirv::ArrayType::get(memberType, numElements, elementSize); -} - -VulkanLayoutUtils::Alignment -VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) { - // According to the Vulkan spec: - // 1. "A scalar of size N has a scalar alignment of N." - // 2. "A scalar has a base alignment equal to its scalar alignment." - // 3. "A scalar, vector or matrix type has an extended alignment equal to its - // base alignment." - auto bitWidth = scalarType.getIntOrFloatBitWidth(); - if (bitWidth == 1) - return 1; - return bitWidth / 8; -} - -bool VulkanLayoutUtils::isLegalType(Type type) { - auto ptrType = type.dyn_cast(); - if (!ptrType) { - return true; - } - - auto storageClass = ptrType.getStorageClass(); - auto structType = ptrType.getPointeeType().dyn_cast(); - if (!structType) { - return true; - } - - switch (storageClass) { - case spirv::StorageClass::Uniform: - case spirv::StorageClass::StorageBuffer: - case spirv::StorageClass::PushConstant: - case spirv::StorageClass::PhysicalStorageBuffer: - return structType.hasLayout() || !structType.getNumElements(); - default: - return true; - } -} -} // namespace - namespace { class SPIRVGlobalVariableOpLayoutInfoDecoration : public OpRewritePattern { @@ -228,7 +39,7 @@ public: PatternMatchResult matchAndRewrite(spirv::GlobalVariableOp op, PatternRewriter &rewriter) const override { spirv::StructType::LayoutInfo structSize = 0; - VulkanLayoutUtils::Alignment structAlignment = 1; + VulkanLayoutUtils::Size structAlignment = 1; SmallVector globalVarAttrs; auto ptrType = op.type().cast(); diff --git a/mlir/test/Conversion/GPUToSPIRV/load_store.mlir b/mlir/test/Conversion/GPUToSPIRV/load_store.mlir index fc3f12d0e3c5..daa975becaf1 100644 --- a/mlir/test/Conversion/GPUToSPIRV/load_store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load_store.mlir @@ -21,13 +21,13 @@ module attributes {gpu.container_module} { // CHECK-DAG: spv.globalVariable [[NUMWORKGROUPSVAR:@.*]] built_in("NumWorkgroups") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr, Input> - // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr>>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr>>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr>>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr [16]> [0]>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [16]> [0]>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr [16]> [0]>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr, StorageBuffer> // CHECK: func [[FN:@.*]]() func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) attributes {gpu.kernel} { diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir index e1642ea0815f..61bc6eaedf75 100644 --- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir @@ -4,8 +4,8 @@ module attributes {gpu.container_module} { module @kernels attributes {gpu.kernel_module} { // CHECK: spv.module "Logical" "GLSL450" { - // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [0]>, StorageBuffer> // CHECK: func [[FN:@.*]]() func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>) attributes { gpu.kernel } {