[spirv] Add basic infrastructure for negative deserializer tests

We are relying on serializer to construct positive cases to drive
the test for deserializer. This leaves negative cases untested.

This CL adds a basic test fixture for covering the negative
corner cases to enforce a more robust deserializer.

Refactored common SPIR-V building methods out of serializer to
share it with the deserialization test.

PiperOrigin-RevId: 260742733
This commit is contained in:
Lei Zhang 2019-07-30 10:21:25 -07:00 committed by A. Unique TensorFlower
parent 4598c04dfe
commit 4a55bd5f28
8 changed files with 318 additions and 59 deletions

View File

@ -15,14 +15,15 @@
// limitations under the License.
// =============================================================================
//
// This file defines common utilities for SPIR-V binary module.
// This file declares common utilities for SPIR-V binary module.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_SPIRV_SERIALIZATION_SPIRV_BINARY_UTILS_H_
#define MLIR_SPIRV_SERIALIZATION_SPIRV_BINARY_UTILS_H_
#ifndef MLIR_DIALECT_SPIRV_SPIRV_BINARY_UTILS_H_
#define MLIR_DIALECT_SPIRV_SPIRV_BINARY_UTILS_H_
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Support/LogicalResult.h"
#include <cstdint>
@ -35,10 +36,17 @@ constexpr unsigned kHeaderWordCount = 5;
/// SPIR-V magic number
constexpr uint32_t kMagicNumber = 0x07230203;
/// The serializer tool ID registered to the Khronos Group
constexpr uint32_t kGeneratorNumber = 22;
/// Auto-generated getOpcode<*Op>() specializations
#define GET_SPIRV_SERIALIZATION_UTILS
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
/// Appends a SPRI-V module header to `header` with the given `idBound`.
void appendModuleHeader(SmallVectorImpl<uint32_t> &header, uint32_t idBound);
} // end namespace spirv
} // end namespace mlir
#endif // MLIR_SPIRV_SERIALIZATION_SPIRV_BINARY_UTILS_H_
#endif // MLIR_DIALECT_SPIRV_SPIRV_BINARY_UTILS_H_

View File

@ -3,6 +3,7 @@ add_llvm_library(MLIRSPIRVSerialization
ConvertToBinary.cpp
Deserializer.cpp
Serializer.cpp
SPIRVBinaryUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV

View File

@ -21,7 +21,7 @@
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
@ -133,9 +133,12 @@ private:
Value *getValue(uint32_t id) { return valueMap.lookup(id); }
/// Slices the first instruction out of `binary` and returns its opcode and
/// operands via `opcode` and `operands` respectively.
LogicalResult sliceInstruction(spirv::Opcode &opcode,
ArrayRef<uint32_t> &operands);
/// operands via `opcode` and `operands` respectively. Returns failure if
/// there is no more remaining instructions (`expectedOpcode` will be used to
/// compose the error message) or the next instruction is malformed.
LogicalResult
sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
Optional<spirv::Opcode> expectedOpcode = llvm::None);
/// Processes a SPIR-V instruction with the given `opcode` and `operands`.
/// This method is the main entrance for handling SPIR-V instruction; it
@ -216,11 +219,20 @@ LogicalResult Deserializer::deserialize() {
spirv::Opcode opcode;
ArrayRef<uint32_t> operands;
while (succeeded(sliceInstruction(opcode, operands))) {
auto binarySize = binary.size();
while (curOffset < binarySize) {
// Slice the next instruction out and populate `opcode` and `operands`.
// Interally this also updates `curOffset`.
if (failed(sliceInstruction(opcode, operands)))
return failure();
if (failed(processInstruction(opcode, operands)))
return failure();
}
assert(curOffset == binarySize &&
"deserializer should never index beyond the binary end");
for (auto &defered : deferedInstructions) {
if (failed(processInstruction(defered.first, defered.second, false))) {
return failure();
@ -324,7 +336,8 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
auto argType = functionType.getInput(i);
spirv::Opcode opcode;
ArrayRef<uint32_t> operands;
if (failed(sliceInstruction(opcode, operands))) {
if (failed(sliceInstruction(opcode, operands,
spirv::Opcode::OpFunctionParameter))) {
return failure();
}
if (opcode != spirv::Opcode::OpFunctionParameter) {
@ -361,19 +374,20 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
spirv::Opcode opcode;
ArrayRef<uint32_t> instOperands;
while (succeeded(sliceInstruction(opcode, instOperands)) &&
while (succeeded(sliceInstruction(opcode, instOperands,
spirv::Opcode::OpFunctionEnd)) &&
opcode != spirv::Opcode::OpFunctionEnd) {
if (failed(processInstruction(opcode, instOperands))) {
return failure();
}
}
std::swap(funcBody, opBuilder);
if (opcode != spirv::Opcode::OpFunctionEnd) {
return failure();
}
if (!instOperands.empty()) {
return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
}
std::swap(funcBody, opBuilder);
return success();
}
@ -750,17 +764,22 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
// Instruction
//===----------------------------------------------------------------------===//
LogicalResult Deserializer::sliceInstruction(spirv::Opcode &opcode,
ArrayRef<uint32_t> &operands) {
LogicalResult
Deserializer::sliceInstruction(spirv::Opcode &opcode,
ArrayRef<uint32_t> &operands,
Optional<spirv::Opcode> expectedOpcode) {
auto binarySize = binary.size();
if (curOffset >= binarySize) {
return failure();
return emitError(unknownLoc, "expected ")
<< (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
: "more")
<< " instruction";
}
// For each instruction, get its word count from the first word to slice it
// from the stream properly, and then dispatch to the instruction handler.
uint32_t wordCount = binary[curOffset] >> 16;
opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff);
if (wordCount == 0)
return emitError(unknownLoc, "word count cannot be zero");
@ -769,6 +788,7 @@ LogicalResult Deserializer::sliceInstruction(spirv::Opcode &opcode,
if (nextOffset > binarySize)
return emitError(unknownLoc, "insufficient words for the last instruction");
opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff);
operands = binary.slice(curOffset + 1, wordCount - 1);
curOffset = nextOffset;
return success();

View File

@ -0,0 +1,53 @@
//===- SPIRVBinaryUtils.cpp - MLIR SPIR-V Binary Module Utilities ---------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines common utilities for SPIR-V binary module.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
using namespace mlir;
void spirv::appendModuleHeader(SmallVectorImpl<uint32_t> &header,
uint32_t idBound) {
// The major and minor version number for the generated SPIR-V binary.
// TODO(antiagainst): use target environment to select the version
constexpr uint8_t kMajorVersion = 1;
constexpr uint8_t kMinorVersion = 0;
// See "2.3. Physical Layout of a SPIR-V Module and Instruction" in the SPIR-V
// spec for the definition of the binary module header.
//
// The first five words of a SPIR-V module must be:
// +-------------------------------------------------------------------------+
// | Magic number |
// +-------------------------------------------------------------------------+
// | Version number (bytes: 0 | major number | minor number | 0) |
// +-------------------------------------------------------------------------+
// | Generator magic number |
// +-------------------------------------------------------------------------+
// | Bound (all result <id>s in the module guaranteed to be less than it) |
// +-------------------------------------------------------------------------+
// | 0 (reserved for instruction schema) |
// +-------------------------------------------------------------------------+
header.push_back(spirv::kMagicNumber);
header.push_back((kMajorVersion << 16) | (kMinorVersion << 8));
header.push_back(kGeneratorNumber);
header.push_back(idBound); // <id> bound
header.push_back(0); // Schema (reserved word)
}

View File

@ -21,7 +21,7 @@
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
@ -116,9 +116,6 @@ private:
// Module structure
//===--------------------------------------------------------------------===//
/// Creates SPIR-V module header in the given `header`.
LogicalResult processHeader();
LogicalResult processMemoryModel();
LogicalResult processConstantOp(spirv::ConstantOp op);
@ -234,7 +231,6 @@ private:
// The following are for different SPIR-V instruction sections. They follow
// the logical layout of a SPIR-V module.
SmallVector<uint32_t, spirv::kHeaderWordCount> header;
SmallVector<uint32_t, 4> capabilities;
SmallVector<uint32_t, 0> extensions;
SmallVector<uint32_t, 0> extendedSets;
@ -282,17 +278,16 @@ LogicalResult Serializer::serialize() {
}
void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
auto moduleSize = header.size() + capabilities.size() + extensions.size() +
extendedSets.size() + memoryModel.size() +
entryPoints.size() + executionModes.size() +
decorations.size() + typesGlobalValues.size() +
functions.size();
auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
extensions.size() + extendedSets.size() +
memoryModel.size() + entryPoints.size() +
executionModes.size() + decorations.size() +
typesGlobalValues.size() + functions.size();
binary.clear();
binary.reserve(moduleSize);
processHeader();
binary.append(header.begin(), header.end());
spirv::appendModuleHeader(binary, nextID);
binary.append(capabilities.begin(), capabilities.end());
binary.append(extensions.begin(), extensions.end());
binary.append(extendedSets.begin(), extendedSets.end());
@ -308,37 +303,6 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
// Module structure
//===----------------------------------------------------------------------===//
LogicalResult Serializer::processHeader() {
// The serializer tool ID registered to the Khronos Group
constexpr uint32_t kGeneratorNumber = 22;
// The major and minor version number for the generated SPIR-V binary.
// TODO(antiagainst): use target environment to select the version
constexpr uint8_t kMajorVersion = 1;
constexpr uint8_t kMinorVersion = 0;
// See "2.3. Physical Layout of a SPIR-V Module and Instruction" in the SPIR-V
// spec for the definition of the binary module header.
//
// The first five words of a SPIR-V module must be:
// +-------------------------------------------------------------------------+
// | Magic number |
// +-------------------------------------------------------------------------+
// | Version number (bytes: 0 | major number | minor number | 0) |
// +-------------------------------------------------------------------------+
// | Generator magic number |
// +-------------------------------------------------------------------------+
// | Bound (all result <id>s in the module guaranteed to be less than it) |
// +-------------------------------------------------------------------------+
// | 0 (reserved for instruction schema) |
// +-------------------------------------------------------------------------+
header.push_back(spirv::kMagicNumber);
header.push_back((kMajorVersion << 16) | (kMinorVersion << 8));
header.push_back(kGeneratorNumber);
header.push_back(nextID); // <id> bound
header.push_back(0); // Schema (reserved word)
return success();
}
LogicalResult Serializer::processMemoryModel() {
uint32_t mm = module.getAttrOfType<IntegerAttr>("memory_model").getInt();
uint32_t am = module.getAttrOfType<IntegerAttr>("addressing_model").getInt();

View File

@ -5,3 +5,5 @@ target_link_libraries(MLIRDialectTests
PRIVATE
MLIRIR
MLIRDialect)
add_subdirectory(SPIRV)

View File

@ -0,0 +1,8 @@
add_mlir_unittest(MLIRSPIRVTests
DeserializationTest.cpp
)
target_link_libraries(MLIRSPIRVTests
PRIVATE
MLIRSPIRV
MLIRSPIRVSerialization)

View File

@ -0,0 +1,203 @@
//===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// The purpose of this file is to provide negative deserialization tests.
// For positive deserialization tests, please use serialization and
// deserialization for roundtripping.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "gmock/gmock.h"
#include <memory>
using namespace mlir;
using ::testing::StrEq;
//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//
/// A deserialization test fixture providing minimal SPIR-V building and
/// diagnostic checking utilities.
class DeserializationTest : public ::testing::Test {
protected:
DeserializationTest() {
// Register a diagnostic handler to capture the diagnostic so that we can
// check it later.
context.getDiagEngine().setHandler([&](Diagnostic diag) {
diagnostic.reset(new Diagnostic(std::move(diag)));
});
}
/// Performs deserialization and returns the constructed spv.module op.
Optional<spirv::ModuleOp> deserialize() {
return spirv::deserialize(binary, &context);
}
/// Checks there is a diagnostic generated with the given `errorMessage`.
void expectDiagnostic(StringRef errorMessage) {
ASSERT_NE(nullptr, diagnostic.get());
// TODO(antiagainst): check error location too.
EXPECT_THAT(diagnostic->str(), StrEq(errorMessage));
}
//===--------------------------------------------------------------------===//
// SPIR-V builder methods
//===--------------------------------------------------------------------===//
/// Adds the SPIR-V module header to `binary`.
void addHeader() { spirv::appendModuleHeader(binary, /*idBound=*/0); }
/// Adds the SPIR-V instruction into `binary`.
void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
uint32_t wordCount = 1 + operands.size();
assert(((wordCount >> 16) == 0) && "word count out of range!");
uint32_t prefixedOpcode = (wordCount << 16) | static_cast<uint32_t>(op);
binary.push_back(prefixedOpcode);
binary.append(operands.begin(), operands.end());
}
uint32_t addVoidType() {
auto id = nextID++;
addInstruction(spirv::Opcode::OpTypeVoid, {id});
return id;
}
uint32_t addIntType(uint32_t bitwidth) {
auto id = nextID++;
addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
return id;
}
uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
auto id = nextID++;
SmallVector<uint32_t, 4> operands;
operands.push_back(id);
operands.push_back(retType);
operands.append(paramTypes.begin(), paramTypes.end());
addInstruction(spirv::Opcode::OpTypeFunction, operands);
return id;
}
uint32_t addFunction(uint32_t retType, uint32_t fnType) {
auto id = nextID++;
addInstruction(spirv::Opcode::OpFunction,
{retType, id,
static_cast<uint32_t>(spirv::FunctionControl::None),
fnType});
return id;
}
uint32_t addFunctionEnd() {
auto id = nextID++;
addInstruction(spirv::Opcode::OpFunctionEnd, {id});
return id;
}
protected:
SmallVector<uint32_t, 5> binary;
uint32_t nextID = 1;
MLIRContext context;
std::unique_ptr<Diagnostic> diagnostic;
};
//===----------------------------------------------------------------------===//
// Basics
//===----------------------------------------------------------------------===//
TEST_F(DeserializationTest, EmptyModuleFailure) {
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("SPIR-V binary module must have a 5-word header");
}
TEST_F(DeserializationTest, WrongMagicNumberFailure) {
addHeader();
binary.front() = 0xdeadbeef; // Change to a wrong magic number
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("incorrect magic number");
}
TEST_F(DeserializationTest, OnlyHeaderSuccess) {
addHeader();
EXPECT_NE(llvm::None, deserialize());
}
TEST_F(DeserializationTest, ZeroWordCountFailure) {
addHeader();
binary.push_back(0); // OpNop with zero word count
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("word count cannot be zero");
}
TEST_F(DeserializationTest, InsufficientWordFailure) {
addHeader();
binary.push_back((2u << 16) |
static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
// Missing word for type <id>
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("insufficient words for the last instruction");
}
//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//
TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
addHeader();
addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}
//===----------------------------------------------------------------------===//
// Functions
//===----------------------------------------------------------------------===//
TEST_F(DeserializationTest, FunctionMissingEndFailure) {
addHeader();
auto voidType = addVoidType();
auto fnType = addFunctionType(voidType, {});
addFunction(voidType, fnType);
// Missing OpFunctionEnd
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("expected OpFunctionEnd instruction");
}
TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
addHeader();
auto voidType = addVoidType();
auto i32Type = addIntType(32);
auto fnType = addFunctionType(voidType, {i32Type});
addFunction(voidType, fnType);
// Missing OpFunctionParameter
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("expected OpFunctionParameter instruction");
}