[mlir] Add bytecode encodings for the builtin ElementsAttr attributes

This adds bytecode support for DenseArrayAttr, DenseIntOrFpElementsAttr,
DenseStringElementsAttr, and SparseElementsAttr.

Differential Revision: https://reviews.llvm.org/D133744
This commit is contained in:
River Riddle 2022-09-12 23:22:26 -07:00
parent 9e0900cbf1
commit 5fb1bbe6d4
6 changed files with 220 additions and 3 deletions

View File

@ -143,6 +143,9 @@ public:
/// Read a string from the bytecode.
virtual LogicalResult readString(StringRef &result) = 0;
/// Read a blob from the bytecode.
virtual LogicalResult readBlob(ArrayRef<char> &result) = 0;
private:
/// Read a handle to a dialect resource.
virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
@ -225,6 +228,11 @@ public:
/// only be called if such a guarantee can be made, such as when the string is
/// owned by an attribute or type.
virtual void writeOwnedString(StringRef str) = 0;
/// Write a blob to the bytecode, which is owned by the caller and is
/// guaranteed to not die before the end of the bytecode process. The blob is
/// written as-is, with no additional compression or compaction.
virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
};
//===----------------------------------------------------------------------===//

View File

@ -887,6 +887,17 @@ public:
return stringReader.parseString(reader, result);
}
LogicalResult readBlob(ArrayRef<char> &result) override {
uint64_t dataSize;
ArrayRef<uint8_t> data;
if (failed(reader.parseVarInt(dataSize)) ||
failed(reader.parseBytes(dataSize, data)))
return failure();
result = llvm::makeArrayRef(reinterpret_cast<const char *>(data.data()),
data.size());
return success();
}
private:
AttrTypeReader &attrTypeReader;
StringSectionReader &stringReader;

View File

@ -543,6 +543,12 @@ public:
emitter.emitVarInt(stringSection.insert(str));
}
void writeOwnedBlob(ArrayRef<char> blob) override {
emitter.emitVarInt(blob.size());
emitter.emitOwnedBlob(ArrayRef<uint8_t>(
reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
}
private:
EncodingEmitter &emitter;
IRNumberingState &numberingState;

View File

@ -39,6 +39,7 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
// references. This could potentially be useful for optimizing things like
// file locations.
}
void writeOwnedBlob(ArrayRef<char> blob) override {}
/// The parent numbering state that is populated by this writer.
IRNumberingState &state;

View File

@ -123,6 +123,32 @@ enum AttributeCode {
/// handle: ResourceHandle
/// }
kDenseResourceElementsAttr = 16,
/// DenseArrayAttr {
/// type: RankedTensorType,
/// data: blob
/// }
kDenseArrayAttr = 17,
/// DenseIntOrFPElementsAttr {
/// type: ShapedType,
/// data: blob
/// }
kDenseIntOrFPElementsAttr = 18,
/// DenseStringElementsAttr {
/// type: ShapedType,
/// isSplat: varint,
/// data: string[]
/// }
kDenseStringElementsAttr = 19,
/// SparseElementsAttr {
/// type: ShapedType,
/// indices: DenseIntElementsAttr,
/// values: DenseElementsAttr
/// }
kSparseElementsAttr = 20,
};
/// This enum contains marker codes used to indicate which type is currently
@ -279,11 +305,18 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
Attribute readAttribute(DialectBytecodeReader &reader) const override;
ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
DenseArrayAttr readDenseArrayAttr(DialectBytecodeReader &reader) const;
DenseElementsAttr
readDenseIntOrFPElementsAttr(DialectBytecodeReader &reader) const;
DenseStringElementsAttr
readDenseStringElementsAttr(DialectBytecodeReader &reader) const;
DenseResourceElementsAttr
readDenseResourceElementsAttr(DialectBytecodeReader &reader) const;
DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
FloatAttr readFloatAttr(DialectBytecodeReader &reader) const;
IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const;
SparseElementsAttr
readSparseElementsAttr(DialectBytecodeReader &reader) const;
StringAttr readStringAttr(DialectBytecodeReader &reader, bool hasType) const;
SymbolRefAttr readSymbolRefAttr(DialectBytecodeReader &reader,
bool hasNestedRefs) const;
@ -298,11 +331,16 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
LogicalResult writeAttribute(Attribute attr,
DialectBytecodeWriter &writer) const override;
void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
void write(DenseArrayAttr attr, DialectBytecodeWriter &writer) const;
void write(DenseIntOrFPElementsAttr attr,
DialectBytecodeWriter &writer) const;
void write(DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const;
void write(DenseResourceElementsAttr attr,
DialectBytecodeWriter &writer) const;
void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
void write(IntegerAttr attr, DialectBytecodeWriter &writer) const;
void write(FloatAttr attr, DialectBytecodeWriter &writer) const;
void write(SparseElementsAttr attr, DialectBytecodeWriter &writer) const;
void write(StringAttr attr, DialectBytecodeWriter &writer) const;
void write(SymbolRefAttr attr, DialectBytecodeWriter &writer) const;
void write(TypeAttr attr, DialectBytecodeWriter &writer) const;
@ -394,6 +432,14 @@ Attribute BuiltinDialectBytecodeInterface::readAttribute(
return UnknownLoc::get(getContext());
case builtin_encoding::kDenseResourceElementsAttr:
return readDenseResourceElementsAttr(reader);
case builtin_encoding::kDenseArrayAttr:
return readDenseArrayAttr(reader);
case builtin_encoding::kDenseIntOrFPElementsAttr:
return readDenseIntOrFPElementsAttr(reader);
case builtin_encoding::kDenseStringElementsAttr:
return readDenseStringElementsAttr(reader);
case builtin_encoding::kSparseElementsAttr:
return readSparseElementsAttr(reader);
default:
reader.emitError() << "unknown builtin attribute code: " << code;
return Attribute();
@ -403,8 +449,10 @@ Attribute BuiltinDialectBytecodeInterface::readAttribute(
LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
Attribute attr, DialectBytecodeWriter &writer) const {
return TypeSwitch<Attribute, LogicalResult>(attr)
.Case<ArrayAttr, DenseResourceElementsAttr, DictionaryAttr, FloatAttr,
IntegerAttr, StringAttr, SymbolRefAttr, TypeAttr>([&](auto attr) {
.Case<ArrayAttr, DenseArrayAttr, DenseIntOrFPElementsAttr,
DenseStringElementsAttr, DenseResourceElementsAttr, DictionaryAttr,
FloatAttr, IntegerAttr, SparseElementsAttr, StringAttr,
SymbolRefAttr, TypeAttr>([&](auto attr) {
write(attr, writer);
return success();
})
@ -441,6 +489,78 @@ void BuiltinDialectBytecodeInterface::write(
writer.writeAttributes(attr.getValue());
}
//===----------------------------------------------------------------------===//
// DenseArrayAttr
DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr(
DialectBytecodeReader &reader) const {
RankedTensorType type;
ArrayRef<char> blob;
if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
return DenseArrayAttr();
return DenseArrayAttr::get(type, blob);
}
void BuiltinDialectBytecodeInterface::write(
DenseArrayAttr attr, DialectBytecodeWriter &writer) const {
writer.writeVarInt(builtin_encoding::kDenseArrayAttr);
writer.writeType(attr.getType());
writer.writeOwnedBlob(attr.getRawData());
}
//===----------------------------------------------------------------------===//
// DenseIntOrFPElementsAttr
DenseElementsAttr BuiltinDialectBytecodeInterface::readDenseIntOrFPElementsAttr(
DialectBytecodeReader &reader) const {
ShapedType type;
ArrayRef<char> blob;
if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
return DenseIntOrFPElementsAttr();
return DenseIntOrFPElementsAttr::getFromRawBuffer(type, blob);
}
void BuiltinDialectBytecodeInterface::write(
DenseIntOrFPElementsAttr attr, DialectBytecodeWriter &writer) const {
writer.writeVarInt(builtin_encoding::kDenseIntOrFPElementsAttr);
writer.writeType(attr.getType());
writer.writeOwnedBlob(attr.getRawData());
}
//===----------------------------------------------------------------------===//
// DenseStringElementsAttr
DenseStringElementsAttr
BuiltinDialectBytecodeInterface::readDenseStringElementsAttr(
DialectBytecodeReader &reader) const {
ShapedType type;
uint64_t isSplat;
if (failed(reader.readType(type)) || failed(reader.readVarInt(isSplat)))
return DenseStringElementsAttr();
SmallVector<StringRef> values(isSplat ? 1 : type.getNumElements());
for (StringRef &value : values)
if (failed(reader.readString(value)))
return DenseStringElementsAttr();
return DenseStringElementsAttr::get(type, values);
}
void BuiltinDialectBytecodeInterface::write(
DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const {
writer.writeVarInt(builtin_encoding::kDenseStringElementsAttr);
writer.writeType(attr.getType());
bool isSplat = attr.isSplat();
writer.writeVarInt(isSplat);
// If the attribute is a splat, only write out the single value.
if (isSplat)
return writer.writeOwnedString(attr.getRawStringData().front());
for (StringRef str : attr.getRawStringData())
writer.writeOwnedString(str);
}
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
@ -550,6 +670,28 @@ void BuiltinDialectBytecodeInterface::write(
writer.writeAPIntWithKnownWidth(attr.getValue());
}
//===----------------------------------------------------------------------===//
// SparseElementsAttr
SparseElementsAttr BuiltinDialectBytecodeInterface::readSparseElementsAttr(
DialectBytecodeReader &reader) const {
ShapedType type;
DenseIntElementsAttr indices;
DenseElementsAttr values;
if (failed(reader.readType(type)) || failed(reader.readAttribute(indices)) ||
failed(reader.readAttribute(values)))
return SparseElementsAttr();
return SparseElementsAttr::get(type, indices, values);
}
void BuiltinDialectBytecodeInterface::write(
SparseElementsAttr attr, DialectBytecodeWriter &writer) const {
writer.writeVarInt(builtin_encoding::kSparseElementsAttr);
writer.writeType(attr.getType());
writer.writeAttribute(attr.getIndices());
writer.writeAttribute(attr.getValues());
}
//===----------------------------------------------------------------------===//
// StringAttr

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -emit-bytecode %s | mlir-opt -mlir-print-local-scope | FileCheck %s
// RUN: mlir-opt -emit-bytecode -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect -mlir-print-local-scope | FileCheck %s
// Bytecode currently does not support big-endian platforms
// UNSUPPORTED: s390x-
@ -13,6 +13,44 @@ module @TestArray attributes {
bytecode.array = [unit]
} {}
//===----------------------------------------------------------------------===//
// DenseArrayAttr
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @TestDenseArray
module @TestDenseArray attributes {
// CHECK: bytecode.test1 = array<i1: true, false, true, false, false>
// CHECK: bytecode.test2 = array<i8: 10, 32, -1>
// CHECK: bytecode.test3 = array<f64: 1.{{.*}}e+01, 3.2{{.*}}e+01, 1.809{{.*}}e+03
bytecode.test1 = array<i1: true, false, true, false, false>,
bytecode.test2 = array<i8: 10, 32, 255>,
bytecode.test3 = array<f64: 10.0, 32.0, 1809.0>
} {}
//===----------------------------------------------------------------------===//
// DenseIntOfFPElementsAttr
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @TestDenseIntOrFPElements
// CHECK: bytecode.test1 = dense<true> : tensor<256xi1>
// CHECK: bytecode.test2 = dense<[10, 32, -1]> : tensor<3xi8>
// CHECK: bytecode.test3 = dense<[1.{{.*}}e+01, 3.2{{.*}}e+01, 1.809{{.*}}e+03]> : tensor<3xf64>
module @TestDenseIntOrFPElements attributes {
bytecode.test1 = dense<true> : tensor<256xi1>,
bytecode.test2 = dense<[10, 32, 255]> : tensor<3xi8>,
bytecode.test3 = dense<[10.0, 32.0, 1809.0]> : tensor<3xf64>
} {}
//===----------------------------------------------------------------------===//
// DenseStringElementsAttr
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @TestDenseStringElementsAttr
module @TestDenseStringElementsAttr attributes {
bytecode.test1 = dense<"splat"> : tensor<256x!bytecode.string>,
bytecode.test2 = dense<["foo", "bar", "baz"]> : tensor<3x!bytecode.string>
} {}
//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//
@ -45,6 +83,17 @@ module @TestInt attributes {
bytecode.int3 = 90000000000000000300000000000000000001 : i128
} {}
//===----------------------------------------------------------------------===//
// SparseElementsAttr
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @TestSparseElements
module @TestSparseElements attributes {
// CHECK-LITERAL: bytecode.sparse = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>
bytecode.sparse = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>
} {}
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//