forked from OSchip/llvm-project
[spirv] Add Block decoration for spv.struct.
Add Block decoration for top-level spv.struct. Closes tensorflow/mlir#102 PiperOrigin-RevId: 265716241
This commit is contained in:
parent
2f59f76876
commit
8f2dfb51d4
|
@ -468,6 +468,13 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
|
|||
}
|
||||
typeDecorations[words[0]] = static_cast<uint32_t>(words[2]);
|
||||
break;
|
||||
case spirv::Decoration::Block:
|
||||
if (words.size() != 2) {
|
||||
return emitError(unknownLoc, "OpDecoration with ")
|
||||
<< decorationName << "needs a single target <id>";
|
||||
}
|
||||
// Block decoration does not affect spv.struct type.
|
||||
break;
|
||||
default:
|
||||
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
|
||||
}
|
||||
|
|
|
@ -174,6 +174,10 @@ private:
|
|||
|
||||
bool isVoidType(Type type) const { return type.isa<NoneType>(); }
|
||||
|
||||
/// Returns true if the given type is a pointer type to a struct in Uniform or
|
||||
/// StorageBuffer storage class.
|
||||
bool isInterfaceStructPtrType(Type type) const;
|
||||
|
||||
/// Main dispatch method for serializing a type. The result <id> of the
|
||||
/// serialized type will be returned as `typeID`.
|
||||
LogicalResult processType(Location loc, Type type, uint32_t &typeID);
|
||||
|
@ -558,6 +562,22 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
|
|||
if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (isInterfaceStructPtrType(varOp.type())) {
|
||||
auto structType = varOp.type()
|
||||
.cast<spirv::PointerType>()
|
||||
.getPointeeType()
|
||||
.cast<spirv::StructType>();
|
||||
SmallVector<uint32_t, 2> args{
|
||||
findTypeID(structType),
|
||||
static_cast<uint32_t>(spirv::Decoration::Block)};
|
||||
if (failed(encodeInstructionInto(decorations, spirv::Opcode::OpDecorate,
|
||||
args))) {
|
||||
return varOp.emitError("cannot decorate ")
|
||||
<< structType << " with Block decoration";
|
||||
}
|
||||
}
|
||||
|
||||
elidedAttrs.push_back("type");
|
||||
SmallVector<uint32_t, 4> operands;
|
||||
operands.push_back(resultTypeID);
|
||||
|
@ -609,6 +629,17 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
|
|||
// Type
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool Serializer::isInterfaceStructPtrType(Type type) const {
|
||||
if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
|
||||
auto storageClass = ptrType.getStorageClass();
|
||||
if (storageClass == spirv::StorageClass::Uniform ||
|
||||
storageClass == spirv::StorageClass::StorageBuffer) {
|
||||
return ptrType.getPointeeType().isa<spirv::StructType>();
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processType(Location loc, Type type,
|
||||
uint32_t &typeID) {
|
||||
typeID = findTypeID(type);
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
add_mlir_unittest(MLIRSPIRVTests
|
||||
DeserializationTest.cpp
|
||||
SerializationTest.cpp
|
||||
)
|
||||
target_link_libraries(MLIRSPIRVTests
|
||||
PRIVATE
|
||||
MLIRSPIRV
|
||||
MLIRSPIRVSerialization)
|
||||
|
||||
whole_archive_link(MLIRSPIRVTests MLIRSPIRV)
|
||||
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
//===- SerializationTest.cpp - SPIR-V Seserialization Tests -------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains corner case tests for the SPIR-V serializer that are not
|
||||
// covered by normal serialization and deserialization roundtripping.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/Serialization.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Fixture
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class SerializationTest : public ::testing::Test {
|
||||
protected:
|
||||
SerializationTest() { createModuleOp(); }
|
||||
|
||||
void createModuleOp() {
|
||||
Builder builder(&context);
|
||||
OperationState state(UnknownLoc::get(&context),
|
||||
spirv::ModuleOp::getOperationName());
|
||||
state.addAttribute("addressing_model",
|
||||
builder.getI32IntegerAttr(static_cast<uint32_t>(
|
||||
spirv::AddressingModel::Logical)));
|
||||
state.addAttribute("memory_model",
|
||||
builder.getI32IntegerAttr(
|
||||
static_cast<uint32_t>(spirv::MemoryModel::GLSL450)));
|
||||
spirv::ModuleOp::build(&builder, &state);
|
||||
module = cast<spirv::ModuleOp>(Operation::create(state));
|
||||
}
|
||||
|
||||
Type getFloatStructType() {
|
||||
OpBuilder opBuilder(module.body());
|
||||
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
|
||||
llvm::SmallVector<spirv::StructType::LayoutInfo, 1> layoutInfo{0};
|
||||
auto structType = spirv::StructType::get(elementTypes, layoutInfo);
|
||||
return structType;
|
||||
}
|
||||
|
||||
void addGlobalVar(Type type, llvm::StringRef name) {
|
||||
OpBuilder opBuilder(module.body());
|
||||
auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
|
||||
opBuilder.create<spirv::GlobalVariableOp>(
|
||||
UnknownLoc::get(&context), opBuilder.getTypeAttr(ptrType),
|
||||
opBuilder.getStringAttr(name), nullptr);
|
||||
}
|
||||
|
||||
bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode,
|
||||
ArrayRef<uint32_t> operands)>
|
||||
matchFn) {
|
||||
auto binarySize = binary.size();
|
||||
auto begin = binary.begin();
|
||||
auto currOffset = spirv::kHeaderWordCount;
|
||||
|
||||
while (currOffset < binarySize) {
|
||||
auto wordCount = binary[currOffset] >> 16;
|
||||
if (!wordCount || (currOffset + wordCount > binarySize)) {
|
||||
return false;
|
||||
}
|
||||
spirv::Opcode opcode =
|
||||
static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);
|
||||
|
||||
if (matchFn(opcode,
|
||||
llvm::ArrayRef<uint32_t>(begin + currOffset + 1,
|
||||
begin + currOffset + wordCount))) {
|
||||
return true;
|
||||
}
|
||||
currOffset += wordCount;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
protected:
|
||||
MLIRContext context;
|
||||
spirv::ModuleOp module;
|
||||
SmallVector<uint32_t, 0> binary;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Block decoration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TEST_F(SerializationTest, BlockDecorationTest) {
|
||||
auto structType = getFloatStructType();
|
||||
addGlobalVar(structType, "var0");
|
||||
ASSERT_TRUE(succeeded(spirv::serialize(module, binary)));
|
||||
auto hasBlockDecoration = [](spirv::Opcode opcode,
|
||||
ArrayRef<uint32_t> operands) -> bool {
|
||||
if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
|
||||
return false;
|
||||
return operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
|
||||
};
|
||||
EXPECT_TRUE(findInstruction(hasBlockDecoration));
|
||||
}
|
Loading…
Reference in New Issue