[spirv] Extend spv.array with Layoutinfo

Extend spv.array with Layoutinfo to support (de)serialization.

Closes tensorflow/mlir#80

PiperOrigin-RevId: 263795304
This commit is contained in:
Denis Khalikov 2019-08-16 10:17:47 -07:00 committed by A. Unique TensorFlower
parent 9c29273ddc
commit cf358017e6
7 changed files with 138 additions and 21 deletions

View File

@ -73,14 +73,23 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
detail::ArrayTypeStorage> {
public:
using Base::Base;
// Zero layout specifies that is no layout
using LayoutInfo = uint64_t;
static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
static ArrayType get(Type elementType, unsigned elementCount);
static ArrayType get(Type elementType, unsigned elementCount,
LayoutInfo layoutInfo);
unsigned getNumElements() const;
Type getElementType() const;
bool hasLayout() const;
uint64_t getArrayStride() const;
};
// SPIR-V image type

View File

@ -53,6 +53,18 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
// Type Parsing
//===----------------------------------------------------------------------===//
// Forward declarations.
template <typename ValTy>
static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
StringRef spec);
template <>
Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
StringRef spec);
template <>
Optional<uint64_t> parseAndVerify(SPIRVDialect const &dialect, Location loc,
StringRef spec);
// Parses "<number> x" from the beginning of `spec`.
static bool parseNumberX(StringRef &spec, int64_t &number) {
spec = spec.ltrim();
@ -150,7 +162,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
// | vector-type
// | spirv-type
//
// array-type ::= `!spv.array<` integer-literal `x` element-type `>`
// array-type ::= `!spv.array<` integer-literal `x` element-type
// (`[` integer-literal `]`)? `>`
static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
Location loc) {
if (!spec.consume_front("array<") || !spec.consume_back(">")) {
@ -171,11 +184,37 @@ static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
return Type();
}
ArrayType::LayoutInfo layoutInfo = 0;
size_t lastLSquare;
// Handle case when element type is not a trivial type
auto lastRDelimiter = spec.rfind('>');
if (lastRDelimiter != StringRef::npos) {
lastLSquare = spec.find('[', lastRDelimiter);
} else {
lastLSquare = spec.rfind('[');
}
if (lastLSquare != StringRef::npos) {
auto layoutSpec = spec.substr(lastLSquare);
auto layout =
parseAndVerify<ArrayType::LayoutInfo>(dialect, loc, layoutSpec);
if (!layout) {
return Type();
}
if (!(layoutInfo = layout.getValue())) {
emitError(loc, "ArrayStride must be greater than zero");
return Type();
}
spec = spec.substr(0, lastLSquare);
}
Type elementType = parseAndVerifyType(dialect, spec, loc);
if (!elementType)
return Type();
return ArrayType::get(elementType, count);
return ArrayType::get(elementType, count, layoutInfo);
}
// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
@ -267,18 +306,17 @@ Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
}
template <>
Optional<spirv::StructType::LayoutInfo>
parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
Optional<uint64_t> parseAndVerify(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
uint64_t offsetVal = std::numeric_limits<uint64_t>::max();
if (!spec.consume_front("[")) {
emitError(loc, "expected '[' while parsing layout specification in '")
<< spec << "'";
return llvm::None;
}
spec = spec.trim();
if (spec.consumeInteger(10, offsetVal)) {
emitError(
loc,
"expected unsigned integer to specify offset of member in struct: '")
emitError(loc, "expected unsigned integer to specify layout information: '")
<< spec << "'";
return llvm::None;
}
@ -292,7 +330,7 @@ parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
<< spec << "'";
return llvm::None;
}
return spirv::StructType::LayoutInfo{offsetVal};
return offsetVal;
}
// Functor object to parse a comma separated list of specs. The function
@ -530,8 +568,11 @@ Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
//===----------------------------------------------------------------------===//
static void print(ArrayType type, llvm::raw_ostream &os) {
os << "array<" << type.getNumElements() << " x " << type.getElementType()
<< ">";
os << "array<" << type.getNumElements() << " x " << type.getElementType();
if (type.hasLayout()) {
os << " [" << type.getArrayStride() << "]";
}
os << ">";
}
static void print(RuntimeArrayType type, llvm::raw_ostream &os) {

View File

@ -34,7 +34,7 @@ using namespace mlir::spirv;
//===----------------------------------------------------------------------===//
struct spirv::detail::ArrayTypeStorage : public TypeStorage {
using KeyTy = std::pair<Type, unsigned>;
using KeyTy = std::tuple<Type, unsigned, ArrayType::LayoutInfo>;
static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
@ -42,18 +42,26 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
}
bool operator==(const KeyTy &key) const {
return key == KeyTy(elementType, getSubclassData());
return key == KeyTy(elementType, getSubclassData(), layoutInfo);
}
ArrayTypeStorage(const KeyTy &key)
: TypeStorage(key.second), elementType(key.first) {}
: TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)),
layoutInfo(std::get<2>(key)) {}
Type elementType;
ArrayType::LayoutInfo layoutInfo;
};
ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
elementCount);
elementCount, 0);
}
ArrayType ArrayType::get(Type elementType, unsigned elementCount,
ArrayType::LayoutInfo layoutInfo) {
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
elementCount, layoutInfo);
}
unsigned ArrayType::getNumElements() const {
@ -62,6 +70,11 @@ unsigned ArrayType::getNumElements() const {
Type ArrayType::getElementType() const { return getImpl()->elementType; }
// ArrayStride must be greater than zero
bool ArrayType::hasLayout() const { return getImpl()->layoutInfo; }
uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; }
//===----------------------------------------------------------------------===//
// CompositeType
//===----------------------------------------------------------------------===//

View File

@ -207,6 +207,9 @@ private:
// Result <id> to decorations mapping.
DenseMap<uint32_t, NamedAttributeList> decorations;
// Result <id> to type decorations.
DenseMap<uint32_t, uint32_t> typeDecorations;
// List of instructions that are processed in a defered fashion (after an
// initial processing of the entire binary). Some operations like
// OpEntryPoint, and OpExecutionMode use forward references to function
@ -330,6 +333,13 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
opBuilder.getStringAttr(stringifyBuiltIn(
static_cast<spirv::BuiltIn>(words[2]))));
break;
case spirv::Decoration::ArrayStride:
if (words.size() != 3) {
return emitError(unknownLoc, "OpDecorate with ")
<< decorationName << " needs a single integer literal";
}
typeDecorations[words[0]] = static_cast<uint32_t>(words[2]);
break;
default:
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
}
@ -590,7 +600,8 @@ LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
<< defOp->getName();
}
typeMap[operands[0]] = spirv::ArrayType::get(elementTy, count);
typeMap[operands[0]] = spirv::ArrayType::get(
elementTy, count, typeDecorations.lookup(operands[0]));
return success();
}

View File

@ -132,6 +132,12 @@ private:
LogicalResult processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr);
template <typename DType>
LogicalResult processTypeDecoration(Location loc, DType type,
uint32_t resultId) {
return emitError(loc, "unhandled decoraion for type:") << type;
}
//===--------------------------------------------------------------------===//
// Types
//===--------------------------------------------------------------------===//
@ -148,7 +154,7 @@ private:
/// Method for preparing basic SPIR-V type serialization. Returns the type's
/// opcode and operands for the instruction via `typeEnum` and `operands`.
LogicalResult prepareBasicType(Location loc, Type type,
LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
spirv::Opcode &typeEnum,
SmallVectorImpl<uint32_t> &operands);
@ -366,6 +372,22 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
}
namespace {
template <>
LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
Location loc, spirv::ArrayType type, uint32_t resultID) {
if (type.hasLayout()) {
// OpDecorate %arrayTypeSSA ArrayStride strideLiteral
SmallVector<uint32_t, 3> args;
args.push_back(resultID);
args.push_back(static_cast<uint32_t>(spirv::Decoration::ArrayStride));
args.push_back(type.getArrayStride());
return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
}
return success();
}
} // namespace
LogicalResult Serializer::processFuncOp(FuncOp op) {
uint32_t fnTypeID = 0;
// Generate type of the function.
@ -445,7 +467,7 @@ LogicalResult Serializer::processType(Location loc, Type type,
if ((type.isa<FunctionType>() &&
succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
operands))) ||
succeeded(prepareBasicType(loc, type, typeEnum, operands))) {
succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands))) {
typeIDMap[type] = typeID;
return encodeInstructionInto(typesGlobalValues, typeEnum, operands);
}
@ -453,7 +475,8 @@ LogicalResult Serializer::processType(Location loc, Type type,
}
LogicalResult
Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
spirv::Opcode &typeEnum,
SmallVectorImpl<uint32_t> &operands) {
if (isVoidType(type)) {
typeEnum = spirv::Opcode::OpTypeVoid;
@ -501,9 +524,8 @@ Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
/*isSpec=*/false)) {
operands.push_back(elementCountID);
return success();
}
return failure();
return processTypeDecoration(loc, arrayType, resultID);
}
if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {

View File

@ -0,0 +1,13 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
func @spirvmodule() {
spv.module "Logical" "VulkanKHR" {
func @array_stride(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32 [4]> [128]>, StorageBuffer>,
%arg1 : i32, %arg2 : i32) {
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32 [4]> [128]>, StorageBuffer>
%2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32 [4]> [128]>, StorageBuffer>
spv.Return
}
}
return
}

View File

@ -12,6 +12,9 @@ func @scalar_array_type(!spv.array<16xf32>, !spv.array<8 x i32>) -> ()
// CHECK: func @vector_array_type(!spv.array<32 x vector<4xf32>>)
func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> ()
// CHECK: func @array_type_stride(!spv.array<4 x !spv.array<4 x f32 [4]> [128]>)
func @array_type_stride(!spv.array< 4 x !spv.array<4 x f32 [4]> [128]>) -> ()
// -----
// expected-error @+1 {{spv.array delimiter <...> mismatch}}
@ -74,6 +77,11 @@ func @llvm_type(!spv.array<4x!llvm.i32>) -> ()
// -----
// expected-error @+1 {{ArrayStride must be greater than zero}}
func @array_type_zero_stide(!spv.array<4xi32 [0]>) -> ()
// -----
//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
@ -246,5 +254,5 @@ func @struct_type_missing_comma2(!spv.struct<f32 [0] i32>) -> ()
// -----
// expected-error @+1 {{expected unsigned integer to specify offset of member in struct}}
// expected-error @+1 {{expected unsigned integer to specify layout info}}
func @struct_type_neg_offset(!spv.struct<f32 [-1]>) -> ()