diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 48044e9de37d..d300725bd480 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -468,6 +468,13 @@ LogicalResult Deserializer::processDecoration(ArrayRef words) { } typeDecorations[words[0]] = static_cast(words[2]); break; + case spirv::Decoration::Block: + if (words.size() != 2) { + return emitError(unknownLoc, "OpDecoration with ") + << decorationName << "needs a single target "; + } + // Block decoration does not affect spv.struct type. + break; default: return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 3f1b01372c92..03973db0d953 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -174,6 +174,10 @@ private: bool isVoidType(Type type) const { return type.isa(); } + /// Returns true if the given type is a pointer type to a struct in Uniform or + /// StorageBuffer storage class. + bool isInterfaceStructPtrType(Type type) const; + /// Main dispatch method for serializing a type. The result of the /// serialized type will be returned as `typeID`. LogicalResult processType(Location loc, Type type, uint32_t &typeID); @@ -558,6 +562,22 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { return failure(); } + + if (isInterfaceStructPtrType(varOp.type())) { + auto structType = varOp.type() + .cast() + .getPointeeType() + .cast(); + SmallVector args{ + findTypeID(structType), + static_cast(spirv::Decoration::Block)}; + if (failed(encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, + args))) { + return varOp.emitError("cannot decorate ") + << structType << " with Block decoration"; + } + } + elidedAttrs.push_back("type"); SmallVector operands; operands.push_back(resultTypeID); @@ -609,6 +629,17 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { // Type //===----------------------------------------------------------------------===// +bool Serializer::isInterfaceStructPtrType(Type type) const { + if (auto ptrType = type.dyn_cast()) { + auto storageClass = ptrType.getStorageClass(); + if (storageClass == spirv::StorageClass::Uniform || + storageClass == spirv::StorageClass::StorageBuffer) { + return ptrType.getPointeeType().isa(); + } + } + return false; +} + LogicalResult Serializer::processType(Location loc, Type type, uint32_t &typeID) { typeID = findTypeID(type); diff --git a/mlir/unittests/Dialect/SPIRV/CMakeLists.txt b/mlir/unittests/Dialect/SPIRV/CMakeLists.txt index 4e851601f270..b444b5c0220a 100644 --- a/mlir/unittests/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/unittests/Dialect/SPIRV/CMakeLists.txt @@ -1,8 +1,11 @@ add_mlir_unittest(MLIRSPIRVTests DeserializationTest.cpp + SerializationTest.cpp ) target_link_libraries(MLIRSPIRVTests PRIVATE MLIRSPIRV MLIRSPIRVSerialization) +whole_archive_link(MLIRSPIRVTests MLIRSPIRV) + diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp new file mode 100644 index 000000000000..65758a7bee7e --- /dev/null +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -0,0 +1,124 @@ +//===- SerializationTest.cpp - SPIR-V Seserialization Tests -------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains corner case tests for the SPIR-V serializer that are not +// covered by normal serialization and deserialization roundtripping. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/Serialization.h" +#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "gmock/gmock.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Test Fixture +//===----------------------------------------------------------------------===// + +class SerializationTest : public ::testing::Test { +protected: + SerializationTest() { createModuleOp(); } + + void createModuleOp() { + Builder builder(&context); + OperationState state(UnknownLoc::get(&context), + spirv::ModuleOp::getOperationName()); + state.addAttribute("addressing_model", + builder.getI32IntegerAttr(static_cast( + spirv::AddressingModel::Logical))); + state.addAttribute("memory_model", + builder.getI32IntegerAttr( + static_cast(spirv::MemoryModel::GLSL450))); + spirv::ModuleOp::build(&builder, &state); + module = cast(Operation::create(state)); + } + + Type getFloatStructType() { + OpBuilder opBuilder(module.body()); + llvm::SmallVector elementTypes{opBuilder.getF32Type()}; + llvm::SmallVector layoutInfo{0}; + auto structType = spirv::StructType::get(elementTypes, layoutInfo); + return structType; + } + + void addGlobalVar(Type type, llvm::StringRef name) { + OpBuilder opBuilder(module.body()); + auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); + opBuilder.create( + UnknownLoc::get(&context), opBuilder.getTypeAttr(ptrType), + opBuilder.getStringAttr(name), nullptr); + } + + bool findInstruction(llvm::function_ref operands)> + matchFn) { + auto binarySize = binary.size(); + auto begin = binary.begin(); + auto currOffset = spirv::kHeaderWordCount; + + while (currOffset < binarySize) { + auto wordCount = binary[currOffset] >> 16; + if (!wordCount || (currOffset + wordCount > binarySize)) { + return false; + } + spirv::Opcode opcode = + static_cast(binary[currOffset] & 0xffff); + + if (matchFn(opcode, + llvm::ArrayRef(begin + currOffset + 1, + begin + currOffset + wordCount))) { + return true; + } + currOffset += wordCount; + } + return false; + } + +protected: + MLIRContext context; + spirv::ModuleOp module; + SmallVector binary; +}; + +//===----------------------------------------------------------------------===// +// Block decoration +//===----------------------------------------------------------------------===// + +TEST_F(SerializationTest, BlockDecorationTest) { + auto structType = getFloatStructType(); + addGlobalVar(structType, "var0"); + ASSERT_TRUE(succeeded(spirv::serialize(module, binary))); + auto hasBlockDecoration = [](spirv::Opcode opcode, + ArrayRef operands) -> bool { + if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2) + return false; + return operands[1] == static_cast(spirv::Decoration::Block); + }; + EXPECT_TRUE(findInstruction(hasBlockDecoration)); +}