2018-10-26 06:46:10 +08:00
|
|
|
//===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
|
|
|
|
#include "mlir/IR/Attributes.h"
|
|
|
|
#include "AttributeDetail.h"
|
2018-10-26 13:13:03 +08:00
|
|
|
#include "mlir/IR/AffineMap.h"
|
2018-10-26 06:46:10 +08:00
|
|
|
#include "mlir/IR/Function.h"
|
2018-10-26 13:13:03 +08:00
|
|
|
#include "mlir/IR/IntegerSet.h"
|
2018-10-26 06:46:10 +08:00
|
|
|
#include "mlir/IR/Types.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::detail;
|
|
|
|
|
|
|
|
/// Returns the kind tag recorded in this attribute's storage object.
Attribute::Kind Attribute::getKind() const {
  return attr->kind;
}
|
|
|
|
|
|
|
|
bool Attribute::isOrContainsFunction() const {
|
|
|
|
return attr->isOrContainsFunctionCache;
|
|
|
|
}
|
|
|
|
|
2018-11-15 06:26:47 +08:00
|
|
|
// Given an attribute that could refer to a function attribute in the remapping
|
|
|
|
// table, walk it and rewrite it to use the mapped function. If it doesn't
|
|
|
|
// refer to anything in the table, then it is returned unmodified.
|
|
|
|
Attribute Attribute::remapFunctionAttrs(
|
|
|
|
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable,
|
|
|
|
MLIRContext *context) const {
|
|
|
|
// Most attributes are trivially unrelated to function attributes, skip them
|
|
|
|
// rapidly.
|
|
|
|
if (!isOrContainsFunction())
|
|
|
|
return *this;
|
|
|
|
|
|
|
|
// If we have a function attribute, remap it.
|
|
|
|
if (auto fnAttr = this->dyn_cast<FunctionAttr>()) {
|
|
|
|
auto it = remappingTable.find(fnAttr);
|
|
|
|
return it != remappingTable.end() ? it->second : *this;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Otherwise, we must have an array attribute, remap the elements.
|
|
|
|
auto arrayAttr = this->cast<ArrayAttr>();
|
|
|
|
SmallVector<Attribute, 8> remappedElts;
|
|
|
|
bool anyChange = false;
|
|
|
|
for (auto elt : arrayAttr.getValue()) {
|
|
|
|
auto newElt = elt.remapFunctionAttrs(remappingTable, context);
|
|
|
|
remappedElts.push_back(newElt);
|
|
|
|
anyChange |= (elt != newElt);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!anyChange)
|
|
|
|
return *this;
|
|
|
|
|
|
|
|
return ArrayAttr::get(remappedElts, context);
|
|
|
|
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// BoolAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the `value` field held by this attribute's storage.
bool BoolAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->value;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// IntegerAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2018-11-12 22:33:22 +08:00
|
|
|
/// Returns the arbitrary-precision integer produced by the storage object's
/// `getValue()`.
APInt IntegerAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->getValue();
}
|
|
|
|
|
2018-11-12 22:33:22 +08:00
|
|
|
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
|
|
|
|
|
2018-11-16 09:53:51 +08:00
|
|
|
/// Returns the `type` field held by this attribute's storage.
Type IntegerAttr::getType() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->type;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// FloatAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the floating-point value produced by the storage object's
/// `getValue()`.
APFloat FloatAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->getValue();
}
|
|
|
|
|
2018-11-16 09:53:51 +08:00
|
|
|
/// Returns the `type` field held by this attribute's storage.
Type FloatAttr::getType() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->type;
}
|
|
|
|
|
2018-12-28 08:51:09 +08:00
|
|
|
double FloatAttr::getValueAsDouble() const {
|
|
|
|
const auto &semantics = getType().cast<FloatType>().getFloatSemantics();
|
|
|
|
auto value = getValue();
|
|
|
|
bool losesInfo = false; // ignored
|
|
|
|
if (&semantics != &APFloat::IEEEdouble()) {
|
|
|
|
value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
|
|
|
|
&losesInfo);
|
|
|
|
}
|
|
|
|
return value.convertToDouble();
|
|
|
|
}
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// StringAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the `value` field held by this attribute's storage.
StringRef StringAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->value;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// ArrayAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the `value` field (the element list) held by this attribute's
/// storage.
ArrayRef<Attribute> ArrayAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->value;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// AffineMapAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the `value` field (the affine map) held by this attribute's
/// storage.
AffineMap AffineMapAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->value;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// IntegerSetAttr
|
2018-10-26 13:13:03 +08:00
|
|
|
|
|
|
|
/// Returns the `value` field (the integer set) held by this attribute's
/// storage.
IntegerSet IntegerSetAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->value;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// TypeAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2018-10-31 05:59:22 +08:00
|
|
|
/// Returns the `value` field (the wrapped type) held by this attribute's
/// storage.
Type TypeAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->value;
}
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// FunctionAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the `value` field (the function pointer) held by this attribute's
/// storage.
Function *FunctionAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->value;
}
|
|
|
|
|
2018-10-31 05:59:22 +08:00
|
|
|
/// Returns the type reported by the function this attribute refers to.
FunctionType FunctionAttr::getType() const {
  Function *function = getValue();
  return function->getType();
}
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// ElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2018-10-31 05:59:22 +08:00
|
|
|
/// Returns the `type` field (a vector or tensor type) held by this
/// attribute's storage.
VectorOrTensorType ElementsAttr::getType() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->type;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// SplatElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the `elt` field (the splatted element attribute) held by this
/// attribute's storage.
Attribute SplatElementsAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->elt;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// DenseElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Appends the elements of this attribute to `values` by forwarding to the
/// concrete dense integer or dense floating-point implementation, selected by
/// the attribute kind. Any other kind is treated as unreachable.
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
  const Attribute::Kind kind = getKind();

  if (kind == Attribute::Kind::DenseIntElements) {
    cast<DenseIntElementsAttr>().getValues(values);
    return;
  }
  if (kind == Attribute::Kind::DenseFPElements) {
    cast<DenseFPElementsAttr>().getValues(values);
    return;
  }
  llvm_unreachable("unexpected element type");
}
|
|
|
|
|
|
|
|
/// Returns the `data` field (the raw element bytes) held by this attribute's
/// storage.
ArrayRef<char> DenseElementsAttr::getRawData() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->data;
}
|
|
|
|
|
|
|
|
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
|
|
|
|
/// starting from `rawData`.
|
2018-12-18 21:25:17 +08:00
|
|
|
void DenseElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
|
|
|
|
uint64_t value) {
|
2019-01-15 05:18:34 +08:00
|
|
|
assert(bitWidth <= 64 && "expected bitWidth to be within 64-bits");
|
|
|
|
|
2018-10-26 06:46:10 +08:00
|
|
|
// Read the destination bytes which will be written to.
|
|
|
|
uint64_t dst = 0;
|
|
|
|
auto dstData = reinterpret_cast<char *>(&dst);
|
|
|
|
auto endPos = bitPos + bitWidth;
|
|
|
|
auto start = data + bitPos / 8;
|
|
|
|
auto end = data + endPos / 8 + (endPos % 8 != 0);
|
|
|
|
std::copy(start, end, dstData);
|
|
|
|
|
|
|
|
// Clean up the invalid bits in the destination bytes.
|
|
|
|
dst &= ~(-1UL << (bitPos % 8));
|
|
|
|
|
|
|
|
// Get the valid bits of the source value, shift them to right position,
|
|
|
|
// then add them to the destination bytes.
|
|
|
|
value <<= bitPos % 8;
|
|
|
|
dst |= value;
|
|
|
|
|
|
|
|
// Write the destination bytes back.
|
|
|
|
ArrayRef<char> range({dstData, (size_t)(end - start)});
|
|
|
|
std::copy(range.begin(), range.end(), start);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
|
|
|
|
/// and put them in the lowest bits.
|
2018-12-18 21:25:17 +08:00
|
|
|
uint64_t DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
|
|
|
|
size_t bitsWidth) {
|
2019-01-15 05:18:34 +08:00
|
|
|
assert(bitsWidth <= 64 && "expected bitWidth to be within 64-bits");
|
|
|
|
|
2018-10-26 06:46:10 +08:00
|
|
|
uint64_t dst = 0;
|
|
|
|
auto dstData = reinterpret_cast<char *>(&dst);
|
|
|
|
auto endPos = bitPos + bitsWidth;
|
|
|
|
auto start = rawData + bitPos / 8;
|
|
|
|
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
|
|
|
|
std::copy(start, end, dstData);
|
|
|
|
|
|
|
|
dst >>= bitPos % 8;
|
|
|
|
dst &= ~(-1UL << bitsWidth);
|
|
|
|
return dst;
|
|
|
|
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// DenseIntElementsAttr
|
2018-12-18 21:25:17 +08:00
|
|
|
|
2018-10-26 06:46:10 +08:00
|
|
|
/// Appends one IntegerAttr per element to `values`. Elements that are exactly
/// 64 bits wide are read directly as `int64_t` words from the raw data; any
/// other width is extracted bit-field by bit-field with `readBits` and
/// sign-extended from its width before being wrapped.
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
  const auto elementBits = static_cast<ImplType *>(attr)->bitsWidth;
  const auto numElements = getType().getNumElements();
  values.reserve(numElements);

  if (elementBits == 64) {
    // The raw buffer is viewed as consecutive 64-bit words.
    const ArrayRef<char> bytes = getRawData();
    ArrayRef<int64_t> words(reinterpret_cast<const int64_t *>(bytes.data()),
                            bytes.size() / 8);
    for (int64_t word : words)
      values.push_back(IntegerAttr::get(getType().getElementType(), word));
    return;
  }

  // Elements are packed back to back, `elementBits` bits each.
  const char *bytes = getRawData().data();
  const auto totalBits = numElements * elementBits;
  for (size_t bitOffset = 0; bitOffset < totalBits; bitOffset += elementBits) {
    const uint64_t rawField = readBits(bytes, bitOffset, elementBits);
    APInt element(elementBits, rawField, /*isSigned=*/true);
    values.push_back(IntegerAttr::get(getType().getElementType(),
                                      element.getSExtValue()));
  }
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// DenseFPElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2018-12-18 21:25:17 +08:00
|
|
|
// Construct a FloatAttr wrapping a float value of `elementType` type from its
|
|
|
|
// bit representation. The APFloat stored in the attribute will have the
|
|
|
|
// semantics defined by the float semantics of the element type.
|
|
|
|
static inline FloatAttr makeFloatAttrFromBits(size_t bitWidth, uint64_t bits,
|
|
|
|
FloatType elementType) {
|
|
|
|
auto apint = APInt(bitWidth, bits);
|
|
|
|
auto apfloat = APFloat(elementType.getFloatSemantics(), apint);
|
|
|
|
return FloatAttr::get(elementType, apfloat);
|
|
|
|
}
|
|
|
|
|
2018-10-26 06:46:10 +08:00
|
|
|
/// Appends one FloatAttr per element to `values`. Elements stored in 64 bits
/// are read directly as 64-bit words from the raw data; any other width is
/// extracted bit-field by bit-field with `readBits`.
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
  const auto numElements = getType().getNumElements();
  auto floatType = getType().getElementType().dyn_cast<FloatType>();
  assert(floatType && "non-float type in FP attribute");

  // FIXME: using 64 bits for BF16 because it is currently stored with double
  // semantics.
  size_t storedBits = floatType.getIntOrFloatBitWidth();
  if (floatType.isBF16())
    storedBits = 64;

  values.reserve(numElements);

  if (storedBits != 64) {
    // Elements are packed back to back, `storedBits` bits each.
    for (unsigned index = 0; index < numElements; ++index) {
      const uint64_t pattern =
          readBits(getRawData().data(), index * storedBits, storedBits);
      values.push_back(makeFloatAttrFromBits(storedBits, pattern, floatType));
    }
    return;
  }

  // The raw buffer is viewed as consecutive 64-bit words.
  const ArrayRef<char> bytes = getRawData();
  ArrayRef<int64_t> words(reinterpret_cast<const int64_t *>(bytes.data()),
                          bytes.size() / 8);
  for (int64_t word : words)
    values.push_back(makeFloatAttrFromBits(64, word, floatType));
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// OpaqueElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the `bytes` field held by this attribute's storage.
StringRef OpaqueElementsAttr::getValue() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->bytes;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// SparseElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the `indices` field (a dense integer elements attribute) held by
/// this attribute's storage.
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->indices;
}
|
|
|
|
|
|
|
|
/// Returns the `values` field (a dense elements attribute) held by this
/// attribute's storage.
DenseElementsAttr SparseElementsAttr::getValues() const {
  auto *impl = static_cast<ImplType *>(attr);
  return impl->values;
}
|