Add support to constant splat vector/tensor attribute.

This attribute represents a reference to a splat vector or tensor, where all
the elements have the same value. The syntax of the attribute is:

`splat<` (tensor-type | vector-type)`,` attribute-value `>`

PiperOrigin-RevId: 216537997
This commit is contained in:
Feng Liu 2018-10-10 08:57:51 -07:00 committed by jpienaar
parent fd06c6bc4e
commit 5e3cca906a
10 changed files with 149 additions and 9 deletions

View File

@ -27,6 +27,7 @@ class Function;
class FunctionType;
class MLIRContext;
class Type;
class VectorOrTensorType;
/// Attributes are known-constant values of operations and functions.
///
@ -43,6 +44,10 @@ public:
Array,
AffineMap,
Function,
SplatElements,
FIRST_ELEMENTS_ATTR = SplatElements,
LAST_ELEMENTS_ATTR = SplatElements,
};
/// Return the classification for this attribute.
@ -249,6 +254,41 @@ private:
Function *value;
};
/// A base attribute represents a reference to a vector or tensor constant.
class ElementsAttr : public Attribute {
public:
ElementsAttr(Kind kind, VectorOrTensorType *type)
: Attribute(kind, /*isOrContainsFunction=*/false), type(type) {}
VectorOrTensorType *getType() const { return type; }
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() >= Kind::FIRST_ELEMENTS_ATTR &&
attr->getKind() <= Kind::LAST_ELEMENTS_ATTR;
}
private:
VectorOrTensorType *type;
};
/// An attribute represents a reference to a splat vecctor or tensor constant,
/// meaning all of the elements have the same value.
class SplatElementsAttr : public ElementsAttr {
public:
static ElementsAttr *get(VectorOrTensorType *type, Attribute *elt);
Attribute *getValue() const { return elt; }
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::SplatElements;
}
private:
SplatElementsAttr(VectorOrTensorType *type, Attribute *elt)
: ElementsAttr(Kind::SplatElements, type), elt(elt) {}
Attribute *elt;
};
} // end namespace mlir.
#endif

View File

@ -43,6 +43,7 @@ class StringAttr;
class TypeAttr;
class ArrayAttr;
class FunctionAttr;
class ElementsAttr;
class AffineMapAttr;
class AffineMap;
@ -98,6 +99,7 @@ public:
AffineMapAttr *getAffineMapAttr(AffineMap map);
TypeAttr *getTypeAttr(Type *type);
FunctionAttr *getFunctionAttr(const Function *value);
ElementsAttr *getSplatElementsAttr(VectorOrTensorType *type, Attribute *elt);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);

View File

@ -303,10 +303,7 @@ public:
/// If any dimension has unknown size (<0), it doesn't have static shape.
/// If all dimensions has known size (>= 0), it has static shape.
bool hasStaticShape() const {
auto dims = getShape();
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
}
bool hasStaticShape() const;
/// If this is ranked tensor or vector type, return the size of the specified
/// dimension. It aborts if the tensor is unranked (this can be checked by

View File

@ -456,6 +456,15 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
}
break;
}
case Attribute::Kind::SplatElements: {
auto *elementsAttr = cast<SplatElementsAttr>(attr);
os << "splat<";
printType(elementsAttr->getType());
os << ", ";
printAttribute(elementsAttr->getValue());
os << '>';
break;
}
}
}

View File

@ -144,6 +144,11 @@ FunctionAttr *Builder::getFunctionAttr(const Function *value) {
return FunctionAttr::get(value, context);
}
ElementsAttr *Builder::getSplatElementsAttr(VectorOrTensorType *type,
Attribute *elt) {
return SplatElementsAttr::get(type, elt);
}
//===----------------------------------------------------------------------===//
// Affine Expressions, Affine Maps, and Integet Sets.
//===----------------------------------------------------------------------===//

View File

@ -274,6 +274,8 @@ public:
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
DenseMap<const Function *, FunctionAttr *> functionAttrs;
DenseMap<std::pair<VectorOrTensorType *, Attribute *>, SplatElementsAttr *>
splatElementsAttrs;
public:
MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {
@ -775,7 +777,6 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
}
}
// Ok, now that we've canonicalized our attributes, unique them.
auto &impl = context->getImpl();
// Look to see if we already have this.
@ -797,6 +798,23 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
return *existing.first = result;
}
ElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type, Attribute *elt) {
auto &impl = type->getContext()->getImpl();
// Look to see if we already have this.
auto *&result = impl.splatElementsAttrs[{type, elt}];
// If we already have it, return that value.
if (result)
return result;
// Otherwise, allocate them into the bump pointer.
result = impl.allocator.Allocate<SplatElementsAttr>();
new (result) SplatElementsAttr(type, elt);
return result;
}
//===----------------------------------------------------------------------===//
// AffineMap and AffineExpr uniquing
//===----------------------------------------------------------------------===//

View File

@ -47,9 +47,8 @@ int VectorOrTensorType::getRank() const {
default:
llvm_unreachable("not a VectorOrTensorType");
case Kind::Vector:
return cast<VectorType>(this)->getShape().size();
case Kind::RankedTensor:
return cast<RankedTensorType>(this)->getShape().size();
return getShape().size();
case Kind::UnrankedTensor:
return -1;
}
@ -58,14 +57,31 @@ int VectorOrTensorType::getRank() const {
int VectorOrTensorType::getDimSize(unsigned i) const {
switch (getKind()) {
case Kind::Vector:
return cast<VectorType>(this)->getShape()[i];
case Kind::RankedTensor:
return cast<RankedTensorType>(this)->getShape()[i];
return getShape()[i];
default:
llvm_unreachable("not a VectorOrTensorType");
}
}
ArrayRef<int> VectorOrTensorType::getShape() const {
switch (getKind()) {
case Kind::Vector:
return cast<VectorType>(this)->getShape();
case Kind::RankedTensor:
return cast<RankedTensorType>(this)->getShape();
case Kind::UnrankedTensor:
return cast<RankedTensorType>(this)->getShape();
default:
llvm_unreachable("not a VectorOrTensorType");
}
}
bool VectorOrTensorType::hasStaticShape() const {
auto dims = getShape();
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
}
VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context)
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),

View File

@ -658,6 +658,8 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
/// | type
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
/// | function-id `:` function-type
/// | `splat<` (tensor-type | vector-type)`,`
/// attribute-value `>`
///
Attribute *Parser::parseAttribute() {
switch (getToken().getKind()) {
@ -752,6 +754,42 @@ Attribute *Parser::parseAttribute() {
return function ? builder.getFunctionAttr(function) : nullptr;
}
case Token::kw_splat: {
consumeToken(Token::kw_splat);
if (parseToken(Token::less, "Expected '<' after 'elements'"))
return nullptr;
auto *type = dyn_cast<VectorOrTensorType>(parseType());
if (!type) {
return (
emitError("expected elements literal has a tensor or vector type"),
nullptr);
}
if (parseToken(Token::comma, "Expected ','"))
return nullptr;
if (!type->hasStaticShape() || type->getRank() == -1) {
return (emitError("tensor literals must be ranked and have static shape"),
nullptr);
}
switch (getToken().getKind()) {
case Token::floatliteral:
case Token::integer:
case Token::minus: {
auto *scalar = parseAttribute();
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return builder.getSplatElementsAttr(type, scalar);
}
default:
return (
emitError("expected '[' or scalar constant inside tensor literal"),
nullptr);
}
}
default: {
if (Type *type = parseType())
return builder.getTypeAttr(type);

View File

@ -94,6 +94,7 @@ TOK_KEYWORD(ceildiv)
TOK_KEYWORD(cfgfunc)
TOK_KEYWORD(cond_br)
TOK_KEYWORD(else)
TOK_KEYWORD(splat)
TOK_KEYWORD(extfunc)
TOK_KEYWORD(f16)
TOK_KEYWORD(f32)

View File

@ -484,3 +484,17 @@ mlfunc @mlfuncsimplemap(%arg0 : index, %arg1 : index) -> () {
}
return
}
// CHECK-LABEL: cfgfunc @tensorattr
cfgfunc @tensorattr() -> () {
bb0:
// CHECK: "splatIntTensor"() {bar: splat<tensor<2x1x4xi32>, 5>} : () -> ()
"splatIntTensor"(){bar: splat<tensor<2x1x4xi32>, 5>} : () -> ()
// CHECK: "splatFloatTensor"() {bar: splat<tensor<2x1x4xf32>, -5.000000e+00>} : () -> ()
"splatFloatTensor"(){bar: splat<tensor<2x1x4xf32>, -5.0>} : () -> ()
// CHECK: "splatIntVector"() {bar: splat<vector<2x1x4xi64>, 5>} : () -> ()
"splatIntVector"(){bar: splat<vector<2x1x4xi64>, 5>} : () -> ()
// CHECK: "splatFloatVector"() {bar: splat<vector<2x1x4xf16>, -5.000000e+00>} : () -> ()
"splatFloatVector"(){bar: splat<vector<2x1x4xf16>, -5.0>} : () -> ()
return
}