Add a constant folding hook to ExtractElementOp to fold extracting the element of a constant. This also adds a 'getValue' function to DenseElementsAttr and SparseElementsAttr to get the element at a constant index.

PiperOrigin-RevId: 230098938
This commit is contained in:
River Riddle 2019-01-19 20:54:09 -08:00 committed by jpienaar
parent 119af6712e
commit 512d87cefc
7 changed files with 170 additions and 11 deletions

View File

@ -348,6 +348,10 @@ public:
static DenseElementsAttr get(VectorOrTensorType type,
ArrayRef<Attribute> values);
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
void getValues(SmallVectorImpl<Attribute> &values) const;
ArrayRef<char> getRawData() const;
@ -427,10 +431,11 @@ public:
/// This class uses COO (coordinate list) encoding to represent the sparse
/// elements in an element attribute. Specifically, the sparse vector/tensor
/// stores the indices and values as two separate dense elements attributes. The
/// dense elements attribute indices is a 2-D tensor with shape [N, ndims],
/// which specifies the indices of the elements in the sparse tensor that
/// contains nonzero values. The dense elements attribute values is a 1-D tensor
/// with shape [N], and it supplies the corresponding values for the indices.
/// dense elements attribute indices is a 2-D tensor of 64-bit integer elements
/// with shape [N, ndims], which specifies the indices of the elements in the
/// sparse tensor that contains nonzero values. The dense elements attribute
/// values is a 1-D tensor with shape [N], and it supplies the corresponding
/// values for the indices.
///
/// For example,
/// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
@ -450,8 +455,8 @@ public:
DenseElementsAttr getValues() const;
/// Return the value at the given index.
Attribute getValue(ArrayRef<unsigned> index) const;
/// Return the value of the element at the given index.
Attribute getValue(ArrayRef<uint64_t> index) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::SparseElements; }

View File

@ -523,6 +523,8 @@ public:
bool verify() const;
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
private:
friend class OperationInst;

View File

@ -149,6 +149,53 @@ Attribute SplatElementsAttr::getValue() const {
/// DenseElementsAttr
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
auto type = getType();
// Verify that the rank of the indices matches the held type.
auto rank = type.getRank();
if (rank != index.size())
return Attribute();
// Verify that all of the indices are within the shape dimensions.
auto shape = type.getShape();
for (unsigned i = 0; i != rank; ++i)
if (shape[i] <= index[i])
return Attribute();
// Reduce the provided multidimensional index into a 1D index.
uint64_t valueIndex = 0;
uint64_t dimMultiplier = 1;
for (int i = rank - 1; i >= 0; --i) {
valueIndex += index[i] * dimMultiplier;
dimMultiplier *= shape[i];
}
// Return the element stored at the 1D index.
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
auto elementType = getType().getElementType();
size_t bitWidth =
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
APInt rawValueData =
readBits(getRawData().data(), valueIndex * bitWidth, bitWidth);
// Convert the raw value data to an attribute value.
switch (getKind()) {
case Attribute::Kind::DenseIntElements:
return IntegerAttr::get(elementType, rawValueData);
case Attribute::Kind::DenseFPElements:
return FloatAttr::get(
elementType, APFloat(elementType.cast<FloatType>().getFloatSemantics(),
rawValueData));
default:
llvm_unreachable("unexpected element type");
}
}
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto elementType = getType().getElementType();
switch (getKind()) {
@ -188,8 +235,8 @@ void DenseElementsAttr::getRawValues(SmallVectorImpl<APInt> &values) const {
auto elementNum = getType().getNumElements();
values.reserve(elementNum);
// FIXME: using 64 bits for BF16 because it is currently stored with double
// semantics.
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
size_t bitWidth =
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
const auto *rawData = getRawData().data();
@ -281,3 +328,40 @@ DenseIntElementsAttr SparseElementsAttr::getIndices() const {
DenseElementsAttr SparseElementsAttr::getValues() const {
return static_cast<ImplType *>(attr)->values;
}
/// Return the value of the element at the given index.
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
auto type = getType();
// Verify that the rank of the indices matches the held type.
auto rank = type.getRank();
if (rank != index.size())
return Attribute();
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
// as a 1-D index array.
auto sparseIndices = getIndices();
const uint64_t *sparseIndexValues =
reinterpret_cast<const uint64_t *>(sparseIndices.getRawData().data());
// Build a mapping between known indices and the offset of the stored element.
llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
size_t numSparseIndices = sparseIndices.getType().getDimSize(0);
for (size_t i = 0, e = numSparseIndices; i != e; ++i)
mappedIndices.try_emplace(
{sparseIndexValues + (i * rank), static_cast<size_t>(rank)}, i);
// Look for the provided index key within the mapped indices. If the provided
// index is not found, then return a zero attribute.
auto it = mappedIndices.find(index);
if (it == mappedIndices.end()) {
auto eltType = type.getElementType();
if (eltType.isa<FloatType>())
return FloatAttr::get(eltType, 0);
assert(eltType.isa<IntegerType>() && "unexpected element type");
return IntegerAttr::get(eltType, 0);
}
// Otherwise, return the held sparse value element.
return getValues().getValue(it->second);
}

View File

@ -1074,8 +1074,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
assert(values.size() == type.getNumElements() &&
"expected 'values' to contain the same number of elements as 'type'");
// FIXME: using 64 bits for BF16 because it is currently stored with double
// semantics.
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
auto eltType = type.getElementType();
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
@ -1136,6 +1136,9 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {
assert(indices.getType().getElementType().isInteger(64) &&
"expected sparse indices to be 64-bit integer values");
auto &impl = type.getContext()->getImpl();
// Look to see if we already have this.

View File

@ -1076,7 +1076,7 @@ Attribute Parser::parseAttribute(Type type) {
switch (getToken().getKind()) {
case Token::l_square: {
/// Parse indices
auto indicesEltType = builder.getIntegerType(32);
auto indicesEltType = builder.getIntegerType(64);
auto indices =
parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
if (!indices)

View File

@ -1136,6 +1136,41 @@ bool ExtractElementOp::verify() const {
return false;
}
Attribute ExtractElementOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() > 1 && "extract_element takes atleast one operands");
// The aggregate operand must be a known constant.
Attribute aggregate = operands.front();
if (!aggregate)
return Attribute();
// If this is a splat elements attribute, simply return the value. All of the
// elements of a splat attribute are the same.
if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
return splatAggregate.getValue();
// Otherwise, collect the constant indices into the aggregate.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : llvm::drop_begin(operands, 1)) {
if (!indice || !indice.isa<IntegerAttr>())
return Attribute();
indices.push_back(indice.cast<IntegerAttr>().getInt());
}
// Get the element value of the aggregate attribute with the given constant
// indices.
switch (aggregate.getKind()) {
case Attribute::Kind::DenseFPElements:
case Attribute::Kind::DenseIntElements:
return aggregate.cast<DenseElementsAttr>().getValue(indices);
case Attribute::Kind::SparseElements:
return aggregate.cast<SparseElementsAttr>().getValue(indices);
default:
return Attribute();
}
}
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

View File

@ -320,6 +320,36 @@ func @cmpi() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}
// CHECK-LABEL: func @fold_extract_element
func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) {
%const_0 = constant 0 : index
%const_1 = constant 1 : index
%const_3 = constant 3 : index
// Fold an extract into a splat.
// CHECK-NEXT: {{.*}} = constant 4.500000e+00 : f32
%0 = constant splat<tensor<4xf32>, 4.5> : tensor<4xf32>
%ext_1 = extract_element %0[%arg0] : tensor<4xf32>
// Fold an extract into a sparse with a sparse index.
// CHECK-NEXT: {{.*}} = constant -2.000000e+00 : f16
%1 = constant sparse<vector<1x1x1xf16>, [[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : vector<1x1x1xf16>
%ext_2 = extract_element %1[%const_1, %const_1, %const_1] : vector<1x1x1xf16>
// Fold an extract into a sparse with a non sparse index.
// CHECK-NEXT: {{.*}} = constant 0.000000e+00 : f16
%2 = constant sparse<vector<1x1x1xf16>, [[1, 1, 1]], [-2.0]> : vector<1x1x1xf16>
%ext_3 = extract_element %2[%const_0, %const_0, %const_0] : vector<1x1x1xf16>
// Fold an extract into a dense tensor.
// CHECK-NEXT: {{.*}} = constant 64 : i32
%3 = constant dense<tensor<2x1x4xi32>, [[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
%ext_4 = extract_element %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
// CHECK-NEXT: return
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
}
// --------------------------------------------------------------------------//
// IMPORTANT NOTE: the operations in this test are exactly those produced by
// lowering affine_apply (i) -> (i mod 42) to standard operations. Please only