forked from OSchip/llvm-project
[spirv] Extend spv.array with Layoutinfo
Extend spv.array with Layoutinfo to support (de)serialization. Closes tensorflow/mlir#80 PiperOrigin-RevId: 263795304
This commit is contained in:
parent
9c29273ddc
commit
cf358017e6
|
@ -73,14 +73,23 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
|
|||
detail::ArrayTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
// Zero layout specifies that is no layout
|
||||
using LayoutInfo = uint64_t;
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
|
||||
|
||||
static ArrayType get(Type elementType, unsigned elementCount);
|
||||
|
||||
static ArrayType get(Type elementType, unsigned elementCount,
|
||||
LayoutInfo layoutInfo);
|
||||
|
||||
unsigned getNumElements() const;
|
||||
|
||||
Type getElementType() const;
|
||||
|
||||
bool hasLayout() const;
|
||||
|
||||
uint64_t getArrayStride() const;
|
||||
};
|
||||
|
||||
// SPIR-V image type
|
||||
|
|
|
@ -53,6 +53,18 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
|
|||
// Type Parsing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Forward declarations.
|
||||
template <typename ValTy>
|
||||
static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
|
||||
StringRef spec);
|
||||
template <>
|
||||
Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
|
||||
StringRef spec);
|
||||
|
||||
template <>
|
||||
Optional<uint64_t> parseAndVerify(SPIRVDialect const &dialect, Location loc,
|
||||
StringRef spec);
|
||||
|
||||
// Parses "<number> x" from the beginning of `spec`.
|
||||
static bool parseNumberX(StringRef &spec, int64_t &number) {
|
||||
spec = spec.ltrim();
|
||||
|
@ -150,7 +162,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
|
|||
// | vector-type
|
||||
// | spirv-type
|
||||
//
|
||||
// array-type ::= `!spv.array<` integer-literal `x` element-type `>`
|
||||
// array-type ::= `!spv.array<` integer-literal `x` element-type
|
||||
// (`[` integer-literal `]`)? `>`
|
||||
static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
|
||||
Location loc) {
|
||||
if (!spec.consume_front("array<") || !spec.consume_back(">")) {
|
||||
|
@ -171,11 +184,37 @@ static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
|
|||
return Type();
|
||||
}
|
||||
|
||||
ArrayType::LayoutInfo layoutInfo = 0;
|
||||
size_t lastLSquare;
|
||||
|
||||
// Handle case when element type is not a trivial type
|
||||
auto lastRDelimiter = spec.rfind('>');
|
||||
if (lastRDelimiter != StringRef::npos) {
|
||||
lastLSquare = spec.find('[', lastRDelimiter);
|
||||
} else {
|
||||
lastLSquare = spec.rfind('[');
|
||||
}
|
||||
|
||||
if (lastLSquare != StringRef::npos) {
|
||||
auto layoutSpec = spec.substr(lastLSquare);
|
||||
auto layout =
|
||||
parseAndVerify<ArrayType::LayoutInfo>(dialect, loc, layoutSpec);
|
||||
if (!layout) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
if (!(layoutInfo = layout.getValue())) {
|
||||
emitError(loc, "ArrayStride must be greater than zero");
|
||||
return Type();
|
||||
}
|
||||
spec = spec.substr(0, lastLSquare);
|
||||
}
|
||||
|
||||
Type elementType = parseAndVerifyType(dialect, spec, loc);
|
||||
if (!elementType)
|
||||
return Type();
|
||||
|
||||
return ArrayType::get(elementType, count);
|
||||
return ArrayType::get(elementType, count, layoutInfo);
|
||||
}
|
||||
|
||||
// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
|
||||
|
@ -267,18 +306,17 @@ Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
|
|||
}
|
||||
|
||||
template <>
|
||||
Optional<spirv::StructType::LayoutInfo>
|
||||
parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
|
||||
Optional<uint64_t> parseAndVerify(SPIRVDialect const &dialect, Location loc,
|
||||
StringRef spec) {
|
||||
uint64_t offsetVal = std::numeric_limits<uint64_t>::max();
|
||||
if (!spec.consume_front("[")) {
|
||||
emitError(loc, "expected '[' while parsing layout specification in '")
|
||||
<< spec << "'";
|
||||
return llvm::None;
|
||||
}
|
||||
spec = spec.trim();
|
||||
if (spec.consumeInteger(10, offsetVal)) {
|
||||
emitError(
|
||||
loc,
|
||||
"expected unsigned integer to specify offset of member in struct: '")
|
||||
emitError(loc, "expected unsigned integer to specify layout information: '")
|
||||
<< spec << "'";
|
||||
return llvm::None;
|
||||
}
|
||||
|
@ -292,7 +330,7 @@ parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
|
|||
<< spec << "'";
|
||||
return llvm::None;
|
||||
}
|
||||
return spirv::StructType::LayoutInfo{offsetVal};
|
||||
return offsetVal;
|
||||
}
|
||||
|
||||
// Functor object to parse a comma separated list of specs. The function
|
||||
|
@ -530,8 +568,11 @@ Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(ArrayType type, llvm::raw_ostream &os) {
|
||||
os << "array<" << type.getNumElements() << " x " << type.getElementType()
|
||||
<< ">";
|
||||
os << "array<" << type.getNumElements() << " x " << type.getElementType();
|
||||
if (type.hasLayout()) {
|
||||
os << " [" << type.getArrayStride() << "]";
|
||||
}
|
||||
os << ">";
|
||||
}
|
||||
|
||||
static void print(RuntimeArrayType type, llvm::raw_ostream &os) {
|
||||
|
|
|
@ -34,7 +34,7 @@ using namespace mlir::spirv;
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct spirv::detail::ArrayTypeStorage : public TypeStorage {
|
||||
using KeyTy = std::pair<Type, unsigned>;
|
||||
using KeyTy = std::tuple<Type, unsigned, ArrayType::LayoutInfo>;
|
||||
|
||||
static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
|
@ -42,18 +42,26 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
|
|||
}
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(elementType, getSubclassData());
|
||||
return key == KeyTy(elementType, getSubclassData(), layoutInfo);
|
||||
}
|
||||
|
||||
ArrayTypeStorage(const KeyTy &key)
|
||||
: TypeStorage(key.second), elementType(key.first) {}
|
||||
: TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)),
|
||||
layoutInfo(std::get<2>(key)) {}
|
||||
|
||||
Type elementType;
|
||||
ArrayType::LayoutInfo layoutInfo;
|
||||
};
|
||||
|
||||
ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
|
||||
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
|
||||
elementCount);
|
||||
elementCount, 0);
|
||||
}
|
||||
|
||||
ArrayType ArrayType::get(Type elementType, unsigned elementCount,
|
||||
ArrayType::LayoutInfo layoutInfo) {
|
||||
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
|
||||
elementCount, layoutInfo);
|
||||
}
|
||||
|
||||
unsigned ArrayType::getNumElements() const {
|
||||
|
@ -62,6 +70,11 @@ unsigned ArrayType::getNumElements() const {
|
|||
|
||||
Type ArrayType::getElementType() const { return getImpl()->elementType; }
|
||||
|
||||
// ArrayStride must be greater than zero
|
||||
bool ArrayType::hasLayout() const { return getImpl()->layoutInfo; }
|
||||
|
||||
uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CompositeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -207,6 +207,9 @@ private:
|
|||
// Result <id> to decorations mapping.
|
||||
DenseMap<uint32_t, NamedAttributeList> decorations;
|
||||
|
||||
// Result <id> to type decorations.
|
||||
DenseMap<uint32_t, uint32_t> typeDecorations;
|
||||
|
||||
// List of instructions that are processed in a defered fashion (after an
|
||||
// initial processing of the entire binary). Some operations like
|
||||
// OpEntryPoint, and OpExecutionMode use forward references to function
|
||||
|
@ -330,6 +333,13 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
|
|||
opBuilder.getStringAttr(stringifyBuiltIn(
|
||||
static_cast<spirv::BuiltIn>(words[2]))));
|
||||
break;
|
||||
case spirv::Decoration::ArrayStride:
|
||||
if (words.size() != 3) {
|
||||
return emitError(unknownLoc, "OpDecorate with ")
|
||||
<< decorationName << " needs a single integer literal";
|
||||
}
|
||||
typeDecorations[words[0]] = static_cast<uint32_t>(words[2]);
|
||||
break;
|
||||
default:
|
||||
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
|
||||
}
|
||||
|
@ -590,7 +600,8 @@ LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
|
|||
<< defOp->getName();
|
||||
}
|
||||
|
||||
typeMap[operands[0]] = spirv::ArrayType::get(elementTy, count);
|
||||
typeMap[operands[0]] = spirv::ArrayType::get(
|
||||
elementTy, count, typeDecorations.lookup(operands[0]));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -132,6 +132,12 @@ private:
|
|||
LogicalResult processDecoration(Location loc, uint32_t resultID,
|
||||
NamedAttribute attr);
|
||||
|
||||
template <typename DType>
|
||||
LogicalResult processTypeDecoration(Location loc, DType type,
|
||||
uint32_t resultId) {
|
||||
return emitError(loc, "unhandled decoraion for type:") << type;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Types
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -148,7 +154,7 @@ private:
|
|||
|
||||
/// Method for preparing basic SPIR-V type serialization. Returns the type's
|
||||
/// opcode and operands for the instruction via `typeEnum` and `operands`.
|
||||
LogicalResult prepareBasicType(Location loc, Type type,
|
||||
LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
|
||||
spirv::Opcode &typeEnum,
|
||||
SmallVectorImpl<uint32_t> &operands);
|
||||
|
||||
|
@ -366,6 +372,22 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
|
|||
return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
|
||||
Location loc, spirv::ArrayType type, uint32_t resultID) {
|
||||
if (type.hasLayout()) {
|
||||
// OpDecorate %arrayTypeSSA ArrayStride strideLiteral
|
||||
SmallVector<uint32_t, 3> args;
|
||||
args.push_back(resultID);
|
||||
args.push_back(static_cast<uint32_t>(spirv::Decoration::ArrayStride));
|
||||
args.push_back(type.getArrayStride());
|
||||
return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
LogicalResult Serializer::processFuncOp(FuncOp op) {
|
||||
uint32_t fnTypeID = 0;
|
||||
// Generate type of the function.
|
||||
|
@ -445,7 +467,7 @@ LogicalResult Serializer::processType(Location loc, Type type,
|
|||
if ((type.isa<FunctionType>() &&
|
||||
succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
|
||||
operands))) ||
|
||||
succeeded(prepareBasicType(loc, type, typeEnum, operands))) {
|
||||
succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands))) {
|
||||
typeIDMap[type] = typeID;
|
||||
return encodeInstructionInto(typesGlobalValues, typeEnum, operands);
|
||||
}
|
||||
|
@ -453,7 +475,8 @@ LogicalResult Serializer::processType(Location loc, Type type,
|
|||
}
|
||||
|
||||
LogicalResult
|
||||
Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
|
||||
Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
|
||||
spirv::Opcode &typeEnum,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
if (isVoidType(type)) {
|
||||
typeEnum = spirv::Opcode::OpTypeVoid;
|
||||
|
@ -501,9 +524,8 @@ Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
|
|||
loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
|
||||
/*isSpec=*/false)) {
|
||||
operands.push_back(elementCountID);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
return processTypeDecoration(loc, arrayType, resultID);
|
||||
}
|
||||
|
||||
if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
|
||||
|
||||
func @spirvmodule() {
|
||||
spv.module "Logical" "VulkanKHR" {
|
||||
func @array_stride(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32 [4]> [128]>, StorageBuffer>,
|
||||
%arg1 : i32, %arg2 : i32) {
|
||||
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32 [4]> [128]>, StorageBuffer>
|
||||
%2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32 [4]> [128]>, StorageBuffer>
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -12,6 +12,9 @@ func @scalar_array_type(!spv.array<16xf32>, !spv.array<8 x i32>) -> ()
|
|||
// CHECK: func @vector_array_type(!spv.array<32 x vector<4xf32>>)
|
||||
func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> ()
|
||||
|
||||
// CHECK: func @array_type_stride(!spv.array<4 x !spv.array<4 x f32 [4]> [128]>)
|
||||
func @array_type_stride(!spv.array< 4 x !spv.array<4 x f32 [4]> [128]>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{spv.array delimiter <...> mismatch}}
|
||||
|
@ -74,6 +77,11 @@ func @llvm_type(!spv.array<4x!llvm.i32>) -> ()
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{ArrayStride must be greater than zero}}
|
||||
func @array_type_zero_stide(!spv.array<4xi32 [0]>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PointerType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -246,5 +254,5 @@ func @struct_type_missing_comma2(!spv.struct<f32 [0] i32>) -> ()
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected unsigned integer to specify offset of member in struct}}
|
||||
// expected-error @+1 {{expected unsigned integer to specify layout info}}
|
||||
func @struct_type_neg_offset(!spv.struct<f32 [-1]>) -> ()
|
||||
|
|
Loading…
Reference in New Issue