forked from OSchip/llvm-project
Add a constant folding hook to ExtractElementOp to fold extracting the element of a constant. This also adds a 'getValue' function to DenseElementsAttr and SparseElementsAttr to get the element at a constant index.
PiperOrigin-RevId: 230098938
This commit is contained in:
parent
119af6712e
commit
512d87cefc
|
@ -348,6 +348,10 @@ public:
|
|||
static DenseElementsAttr get(VectorOrTensorType type,
|
||||
ArrayRef<Attribute> values);
|
||||
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
/// element, then a null attribute is returned.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||
|
||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||
|
||||
ArrayRef<char> getRawData() const;
|
||||
|
@ -427,10 +431,11 @@ public:
|
|||
/// This class uses COO (coordinate list) encoding to represent the sparse
|
||||
/// elements in an element attribute. Specifically, the sparse vector/tensor
|
||||
/// stores the indices and values as two separate dense elements attributes. The
|
||||
/// dense elements attribute indices is a 2-D tensor with shape [N, ndims],
|
||||
/// which specifies the indices of the elements in the sparse tensor that
|
||||
/// contains nonzero values. The dense elements attribute values is a 1-D tensor
|
||||
/// with shape [N], and it supplies the corresponding values for the indices.
|
||||
/// dense elements attribute indices is a 2-D tensor of 64-bit integer elements
|
||||
/// with shape [N, ndims], which specifies the indices of the elements in the
|
||||
/// sparse tensor that contains nonzero values. The dense elements attribute
|
||||
/// values is a 1-D tensor with shape [N], and it supplies the corresponding
|
||||
/// values for the indices.
|
||||
///
|
||||
/// For example,
|
||||
/// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
|
||||
|
@ -450,8 +455,8 @@ public:
|
|||
|
||||
DenseElementsAttr getValues() const;
|
||||
|
||||
/// Return the value at the given index.
|
||||
Attribute getValue(ArrayRef<unsigned> index) const;
|
||||
/// Return the value of the element at the given index.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::SparseElements; }
|
||||
|
|
|
@ -523,6 +523,8 @@ public:
|
|||
bool verify() const;
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
||||
private:
|
||||
friend class OperationInst;
|
||||
|
|
|
@ -149,6 +149,53 @@ Attribute SplatElementsAttr::getValue() const {
|
|||
|
||||
/// DenseElementsAttr
|
||||
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
/// element, then a null attribute is returned.
|
||||
Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
||||
auto type = getType();
|
||||
|
||||
// Verify that the rank of the indices matches the held type.
|
||||
auto rank = type.getRank();
|
||||
if (rank != index.size())
|
||||
return Attribute();
|
||||
|
||||
// Verify that all of the indices are within the shape dimensions.
|
||||
auto shape = type.getShape();
|
||||
for (unsigned i = 0; i != rank; ++i)
|
||||
if (shape[i] <= index[i])
|
||||
return Attribute();
|
||||
|
||||
// Reduce the provided multidimensional index into a 1D index.
|
||||
uint64_t valueIndex = 0;
|
||||
uint64_t dimMultiplier = 1;
|
||||
for (int i = rank - 1; i >= 0; --i) {
|
||||
valueIndex += index[i] * dimMultiplier;
|
||||
dimMultiplier *= shape[i];
|
||||
}
|
||||
|
||||
// Return the element stored at the 1D index.
|
||||
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
auto elementType = getType().getElementType();
|
||||
size_t bitWidth =
|
||||
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
||||
APInt rawValueData =
|
||||
readBits(getRawData().data(), valueIndex * bitWidth, bitWidth);
|
||||
|
||||
// Convert the raw value data to an attribute value.
|
||||
switch (getKind()) {
|
||||
case Attribute::Kind::DenseIntElements:
|
||||
return IntegerAttr::get(elementType, rawValueData);
|
||||
case Attribute::Kind::DenseFPElements:
|
||||
return FloatAttr::get(
|
||||
elementType, APFloat(elementType.cast<FloatType>().getFloatSemantics(),
|
||||
rawValueData));
|
||||
default:
|
||||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
}
|
||||
|
||||
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto elementType = getType().getElementType();
|
||||
switch (getKind()) {
|
||||
|
@ -188,8 +235,8 @@ void DenseElementsAttr::getRawValues(SmallVectorImpl<APInt> &values) const {
|
|||
auto elementNum = getType().getNumElements();
|
||||
values.reserve(elementNum);
|
||||
|
||||
// FIXME: using 64 bits for BF16 because it is currently stored with double
|
||||
// semantics.
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
size_t bitWidth =
|
||||
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
||||
const auto *rawData = getRawData().data();
|
||||
|
@ -281,3 +328,40 @@ DenseIntElementsAttr SparseElementsAttr::getIndices() const {
|
|||
DenseElementsAttr SparseElementsAttr::getValues() const {
|
||||
return static_cast<ImplType *>(attr)->values;
|
||||
}
|
||||
|
||||
/// Return the value of the element at the given index.
|
||||
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
||||
auto type = getType();
|
||||
|
||||
// Verify that the rank of the indices matches the held type.
|
||||
auto rank = type.getRank();
|
||||
if (rank != index.size())
|
||||
return Attribute();
|
||||
|
||||
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
|
||||
// as a 1-D index array.
|
||||
auto sparseIndices = getIndices();
|
||||
const uint64_t *sparseIndexValues =
|
||||
reinterpret_cast<const uint64_t *>(sparseIndices.getRawData().data());
|
||||
|
||||
// Build a mapping between known indices and the offset of the stored element.
|
||||
llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
|
||||
size_t numSparseIndices = sparseIndices.getType().getDimSize(0);
|
||||
for (size_t i = 0, e = numSparseIndices; i != e; ++i)
|
||||
mappedIndices.try_emplace(
|
||||
{sparseIndexValues + (i * rank), static_cast<size_t>(rank)}, i);
|
||||
|
||||
// Look for the provided index key within the mapped indices. If the provided
|
||||
// index is not found, then return a zero attribute.
|
||||
auto it = mappedIndices.find(index);
|
||||
if (it == mappedIndices.end()) {
|
||||
auto eltType = type.getElementType();
|
||||
if (eltType.isa<FloatType>())
|
||||
return FloatAttr::get(eltType, 0);
|
||||
assert(eltType.isa<IntegerType>() && "unexpected element type");
|
||||
return IntegerAttr::get(eltType, 0);
|
||||
}
|
||||
|
||||
// Otherwise, return the held sparse value element.
|
||||
return getValues().getValue(it->second);
|
||||
}
|
||||
|
|
|
@ -1074,8 +1074,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
|||
assert(values.size() == type.getNumElements() &&
|
||||
"expected 'values' to contain the same number of elements as 'type'");
|
||||
|
||||
// FIXME: using 64 bits for BF16 because it is currently stored with double
|
||||
// semantics.
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
auto eltType = type.getElementType();
|
||||
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
||||
|
||||
|
@ -1136,6 +1136,9 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
|
|||
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values) {
|
||||
assert(indices.getType().getElementType().isInteger(64) &&
|
||||
"expected sparse indices to be 64-bit integer values");
|
||||
|
||||
auto &impl = type.getContext()->getImpl();
|
||||
|
||||
// Look to see if we already have this.
|
||||
|
|
|
@ -1076,7 +1076,7 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
switch (getToken().getKind()) {
|
||||
case Token::l_square: {
|
||||
/// Parse indices
|
||||
auto indicesEltType = builder.getIntegerType(32);
|
||||
auto indicesEltType = builder.getIntegerType(64);
|
||||
auto indices =
|
||||
parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
|
||||
if (!indices)
|
||||
|
|
|
@ -1136,6 +1136,41 @@ bool ExtractElementOp::verify() const {
|
|||
return false;
|
||||
}
|
||||
|
||||
Attribute ExtractElementOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
assert(operands.size() > 1 && "extract_element takes atleast one operands");
|
||||
|
||||
// The aggregate operand must be a known constant.
|
||||
Attribute aggregate = operands.front();
|
||||
if (!aggregate)
|
||||
return Attribute();
|
||||
|
||||
// If this is a splat elements attribute, simply return the value. All of the
|
||||
// elements of a splat attribute are the same.
|
||||
if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
|
||||
return splatAggregate.getValue();
|
||||
|
||||
// Otherwise, collect the constant indices into the aggregate.
|
||||
SmallVector<uint64_t, 8> indices;
|
||||
for (Attribute indice : llvm::drop_begin(operands, 1)) {
|
||||
if (!indice || !indice.isa<IntegerAttr>())
|
||||
return Attribute();
|
||||
indices.push_back(indice.cast<IntegerAttr>().getInt());
|
||||
}
|
||||
|
||||
// Get the element value of the aggregate attribute with the given constant
|
||||
// indices.
|
||||
switch (aggregate.getKind()) {
|
||||
case Attribute::Kind::DenseFPElements:
|
||||
case Attribute::Kind::DenseIntElements:
|
||||
return aggregate.cast<DenseElementsAttr>().getValue(indices);
|
||||
case Attribute::Kind::SparseElements:
|
||||
return aggregate.cast<SparseElementsAttr>().getValue(indices);
|
||||
default:
|
||||
return Attribute();
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -320,6 +320,36 @@ func @cmpi() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
|
|||
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_extract_element
|
||||
func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) {
|
||||
%const_0 = constant 0 : index
|
||||
%const_1 = constant 1 : index
|
||||
%const_3 = constant 3 : index
|
||||
|
||||
// Fold an extract into a splat.
|
||||
// CHECK-NEXT: {{.*}} = constant 4.500000e+00 : f32
|
||||
%0 = constant splat<tensor<4xf32>, 4.5> : tensor<4xf32>
|
||||
%ext_1 = extract_element %0[%arg0] : tensor<4xf32>
|
||||
|
||||
// Fold an extract into a sparse with a sparse index.
|
||||
// CHECK-NEXT: {{.*}} = constant -2.000000e+00 : f16
|
||||
%1 = constant sparse<vector<1x1x1xf16>, [[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : vector<1x1x1xf16>
|
||||
%ext_2 = extract_element %1[%const_1, %const_1, %const_1] : vector<1x1x1xf16>
|
||||
|
||||
// Fold an extract into a sparse with a non sparse index.
|
||||
// CHECK-NEXT: {{.*}} = constant 0.000000e+00 : f16
|
||||
%2 = constant sparse<vector<1x1x1xf16>, [[1, 1, 1]], [-2.0]> : vector<1x1x1xf16>
|
||||
%ext_3 = extract_element %2[%const_0, %const_0, %const_0] : vector<1x1x1xf16>
|
||||
|
||||
// Fold an extract into a dense tensor.
|
||||
// CHECK-NEXT: {{.*}} = constant 64 : i32
|
||||
%3 = constant dense<tensor<2x1x4xi32>, [[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
|
||||
%ext_4 = extract_element %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
|
||||
|
||||
// CHECK-NEXT: return
|
||||
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------//
|
||||
// IMPORTANT NOTE: the operations in this test are exactly those produced by
|
||||
// lowering affine_apply (i) -> (i mod 42) to standard operations. Please only
|
||||
|
|
Loading…
Reference in New Issue