Give custom ops the ability to also access general additional attributes in the

parser and printer.  Fix the spelling of 'delimeter'

PiperOrigin-RevId: 207189892
This commit is contained in:
Chris Lattner 2018-08-02 16:54:36 -07:00 committed by jpienaar
parent 6472f5fbbb
commit 316e884367
8 changed files with 184 additions and 91 deletions

View File

@ -99,6 +99,9 @@ public:
const Operation *getOperation() const { return state; }
Operation *getOperation() { return state; }
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() const { return state->getAttrs(); }
/// Return an attribute with the specified name.
Attribute *getAttr(StringRef name) const { return state->getAttr(name); }

View File

@ -69,6 +69,14 @@ public:
virtual void printAffineMap(const AffineMap *map) = 0;
virtual void printAffineExpr(const AffineExpr *expr) = 0;
/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
/// specific well known attributes, commonly used if the attribute value is
/// printed some other way (like as a fixed operand).
virtual void
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {}) = 0;
/// Print the entire operation with the default verbose formatting.
virtual void printDefaultOp(const Operation *op) = 0;
@ -127,7 +135,7 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
///
/// The "%x = load" tokens are already parsed and therefore invisible to the
/// custom op parser. This can be supported by calling `parseOperandList` to
/// parse the %p, then calling `parseOperandList` with a `SquareDelimeter` to
/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
/// parse the indices, then calling `parseColonTypeList` to parse the result
/// type.
///
@ -174,17 +182,23 @@ public:
virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result,
llvm::SMLoc *loc = nullptr) = 0;
/// Parse an attribute.
virtual bool parseAttribute(Attribute *&result,
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name. this
/// captures the location of the attribute in 'loc' if it is non-null.
virtual bool parseAttribute(Attribute *&result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs,
llvm::SMLoc *loc = nullptr) = 0;
/// Parse an attribute of a specific kind.
/// Parse an attribute of a specific kind, capturing the location into `loc`
/// if specified.
template <typename AttrType>
bool parseAttribute(AttrType *&result, llvm::SMLoc *loc = nullptr) {
bool parseAttribute(AttrType *&result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs,
llvm::SMLoc *loc = nullptr) {
// Parse any kind of attribute.
Attribute *attr;
llvm::SMLoc tmpLoc;
if (parseAttribute(attr, &tmpLoc))
if (parseAttribute(attr, attrName, attrs, &tmpLoc))
return true;
if (loc)
*loc = tmpLoc;
@ -199,6 +213,11 @@ public:
return false;
}
/// If a named attribute dictionary is present, parse it into result.
virtual bool
parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result,
llvm::SMLoc *loc = nullptr) = 0;
/// This is the representation of an operand reference.
struct OperandType {
llvm::SMLoc location; // Location of the token.
@ -209,27 +228,26 @@ public:
/// Parse a single operand.
virtual bool parseOperand(OperandType &result) = 0;
/// These are the supported delimeters around operand lists, used by
/// These are the supported delimiters around operand lists, used by
/// parseOperandList.
enum Delimeter {
/// Zero or more operands with no delimeters.
NoDelimeter,
enum Delimiter {
/// Zero or more operands with no delimiters.
None,
/// Parens surrounding zero or more operands.
ParenDelimeter,
Paren,
/// Square brackets surrounding zero or more operands.
SquareDelimeter,
Square,
/// Parens supporting zero or more operands, or nothing.
OptionalParenDelimeter,
OptionalParen,
/// Square brackets supporting zero or more ops, or nothing.
OptionalSquareDelimeter,
OptionalSquare,
};
/// Parse zero or more SSA comma-separated operand references with a specified
/// surrounding delimeter, and an optional required operand count.
virtual bool
parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimeter delimeter = Delimeter::NoDelimeter) = 0;
/// surrounding delimiter, and an optional required operand count.
virtual bool parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser

View File

@ -101,6 +101,7 @@ public:
// (maybe a dozen or so, but not hundreds or thousands) so we use linear
// searches for everything.
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() const;
/// Return the specified attribute if present, null otherwise.

View File

@ -579,6 +579,8 @@ public:
}
void printOperand(const SSAValue *value) { printValueID(value); }
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {}) override;
enum { nameSentinel = ~0U };
@ -711,6 +713,44 @@ private:
};
} // end anonymous namespace
void FunctionPrinter::printOptionalAttrDict(
ArrayRef<NamedAttribute> attrs, ArrayRef<const char *> elidedAttrs) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
// Filter out any attributes that shouldn't be included.
SmallVector<NamedAttribute, 8> filteredAttrs;
for (auto attr : attrs) {
auto attrName = attr.first.str();
// Never print attributes that start with a colon. These are internal
// attributes that represent location or other internal metadata.
if (attrName.startswith(":"))
continue;
// If the caller has requested that this attribute be ignored, then drop it.
bool ignore = false;
for (const char *elide : elidedAttrs)
ignore |= attrName == StringRef(elide);
// Otherwise add it to our filteredAttrs list.
if (!ignore)
filteredAttrs.push_back(attr);
}
// If there are no attributes left to print after filtering, then we're done.
if (filteredAttrs.empty())
return;
// Otherwise, print them all out in braces.
os << " {";
interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
os << attr.first << ": ";
printAttribute(attr.second);
});
os << '}';
}
void FunctionPrinter::printOperation(const Operation *op) {
if (op->getNumResults()) {
printValueID(op->getResult(0), /*printResultNo=*/false);
@ -737,14 +777,7 @@ void FunctionPrinter::printDefaultOp(const Operation *op) {
os << ')';
auto attrs = op->getAttrs();
if (!attrs.empty()) {
os << '{';
interleaveComma(attrs, [&](NamedAttribute attr) {
os << attr.first << ": ";
printAttribute(attr.second);
});
os << '}';
}
printOptionalAttrDict(attrs);
// Print the type signature of the operation.
os << " : (";

View File

@ -46,16 +46,15 @@ static bool
parseDimAndSymbolList(OpAsmParser *parser,
SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimeter::ParenDelimeter))
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
return true;
// Store number of dimensions for validation by caller.
numDims = opInfos.size();
// Parse the optional symbol operands.
auto *affineIntTy = parser->getBuilder().getAffineIntType();
if (parser->parseOperandList(
opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
return true;
return false;
@ -67,17 +66,21 @@ OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
SSAValue *lhs, *rhs;
if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
SmallVector<NamedAttribute, 4> attrs;
if (parser->parseOperandList(ops, 2) ||
parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(type) ||
parser->resolveOperand(ops[0], type, lhs) ||
parser->resolveOperand(ops[1], type, rhs))
return {};
return OpAsmParserResult({lhs, rhs}, type);
return OpAsmParserResult({lhs, rhs}, type, attrs);
}
void AddFOp::print(OpAsmPrinter *p) const {
*p << "addf " << *getOperand(0) << ", " << *getOperand(1) << " : "
<< *getType();
*p << "addf " << *getOperand(0) << ", " << *getOperand(1);
p->printOptionalAttrDict(getAttrs());
*p << " : " << *getType();
}
// Return an error message on failure.
@ -91,14 +94,16 @@ const char *AddFOp::verify() const {
OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
SmallVector<OpAsmParser::OperandType, 2> opInfos;
SmallVector<SSAValue *, 4> operands;
SmallVector<NamedAttribute, 4> attrs;
auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getAffineIntType();
AffineMapAttr *mapAttr;
unsigned numDims;
if (parser->parseAttribute(mapAttr) ||
parseDimAndSymbolList(parser, opInfos, operands, numDims))
if (parser->parseAttribute(mapAttr, "map", attrs) ||
parseDimAndSymbolList(parser, opInfos, operands, numDims) ||
parser->parseOptionalAttributeDict(attrs))
return {};
auto *map = mapAttr->getValue();
@ -110,15 +115,14 @@ OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) {
}
SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
return OpAsmParserResult(
operands, resultTypes,
NamedAttribute(builder.getIdentifier("map"), mapAttr));
return OpAsmParserResult(operands, resultTypes, attrs);
}
void AffineApplyOp::print(OpAsmPrinter *p) const {
auto *map = getAffineMap();
*p << "affine_apply " << *map;
printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
}
const char *AffineApplyOp::verify() const {
@ -147,7 +151,7 @@ void AllocOp::print(OpAsmPrinter *p) const {
// Print dynamic dimension operands.
printDimAndSymbolList(operand_begin(), operand_end(),
type->getNumDynamicDims(), p);
// Print memref type.
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
*p << " : " << *type;
}
@ -155,12 +159,13 @@ OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
MemRefType *type;
SmallVector<SSAValue *, 4> operands;
SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
SmallVector<NamedAttribute, 4> attrs;
// Parse the dimension operands and optional symbol operands, followed by a
// memref type.
unsigned numDimOperands;
if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands) ||
parser->parseColonType(type))
parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
return {};
// Check numDynamicDims against number of question marks in memref type.
@ -182,7 +187,7 @@ OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
return {};
}
return OpAsmParserResult(operands, type);
return OpAsmParserResult(operands, type, attrs);
}
const char *AllocOp::verify() const {
@ -191,19 +196,20 @@ const char *AllocOp::verify() const {
}
void ConstantOp::print(OpAsmPrinter *p) const {
*p << "constant " << *getValue() << " : " << *getType();
*p << "constant " << *getValue();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
*p << " : " << *getType();
}
OpAsmParserResult ConstantOp::parse(OpAsmParser *parser) {
Attribute *valueAttr;
Type *type;
if (parser->parseAttribute(valueAttr) || parser->parseColonType(type))
return {};
SmallVector<NamedAttribute, 4> attrs;
auto &builder = parser->getBuilder();
return OpAsmParserResult(
/*operands=*/{}, type,
NamedAttribute(builder.getIdentifier("value"), valueAttr));
if (parser->parseAttribute(valueAttr, "value", attrs) ||
parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type))
return {};
return OpAsmParserResult(/*operands=*/{}, type, attrs);
}
/// The constant op requires an attribute, and furthermore requires that it
@ -236,8 +242,9 @@ bool ConstantIntOp::isClassFor(const Operation *op) {
}
void DimOp::print(OpAsmPrinter *p) const {
*p << "dim " << *getOperand() << ", " << getIndex() << " : "
<< *getOperand()->getType();
*p << "dim " << *getOperand() << ", " << getIndex();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
*p << " : " << *getOperand()->getType();
}
OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
@ -245,15 +252,17 @@ OpAsmParserResult DimOp::parse(OpAsmParser *parser) {
IntegerAttr *indexAttr;
Type *type;
SSAValue *operand;
SmallVector<NamedAttribute, 4> attrs;
if (parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr) || parser->parseColonType(type) ||
parser->parseAttribute(indexAttr, "index", attrs) ||
parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(type) ||
parser->resolveOperand(operandInfo, type, operand))
return {};
auto &builder = parser->getBuilder();
return OpAsmParserResult(
operand, builder.getAffineIntType(),
NamedAttribute(builder.getIdentifier("index"), indexAttr));
return OpAsmParserResult(operand, builder.getAffineIntType(), attrs);
}
const char *DimOp::verify() const {
@ -283,7 +292,9 @@ const char *DimOp::verify() const {
void LoadOp::print(OpAsmPrinter *p) const {
*p << "load " << *getMemRef() << '[';
p->printOperands(getIndices());
*p << "] : " << *getMemRef()->getType();
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << *getMemRef()->getType();
}
OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
@ -291,17 +302,18 @@ OpAsmParserResult LoadOp::parse(OpAsmParser *parser) {
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType *type;
SmallVector<SSAValue *, 4> operands;
SmallVector<NamedAttribute, 4> attrs;
auto affineIntTy = parser->getBuilder().getAffineIntType();
if (parser->parseOperand(memrefInfo) ||
parser->parseOperandList(indexInfo, -1,
OpAsmParser::Delimeter::SquareDelimeter) ||
parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(type) ||
parser->resolveOperands(memrefInfo, type, operands) ||
parser->resolveOperands(indexInfo, affineIntTy, operands))
return {};
return OpAsmParserResult(operands, type->getElementType());
return OpAsmParserResult(operands, type->getElementType(), attrs);
}
const char *LoadOp::verify() const {
@ -327,7 +339,9 @@ void StoreOp::print(OpAsmPrinter *p) const {
*p << "store " << *getValueToStore();
*p << ", " << *getMemRef() << '[';
p->printOperands(getIndices());
*p << "] : " << *getMemRef()->getType();
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << *getMemRef()->getType();
}
OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
@ -336,12 +350,13 @@ OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
SmallVector<SSAValue *, 4> operands;
MemRefType *memrefType;
SmallVector<NamedAttribute, 4> attrs;
auto affineIntTy = parser->getBuilder().getAffineIntType();
if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
parser->parseOperand(memrefInfo) ||
parser->parseOperandList(indexInfo, -1,
OpAsmParser::Delimeter::SquareDelimeter) ||
parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(memrefType) ||
parser->resolveOperands(storeValueInfo, memrefType->getElementType(),
operands) ||
@ -349,7 +364,7 @@ OpAsmParserResult StoreOp::parse(OpAsmParser *parser) {
parser->resolveOperands(indexInfo, affineIntTy, operands))
return {};
return OpAsmParserResult(operands, {});
return OpAsmParserResult(operands, {}, attrs);
}
const char *StoreOp::verify() const {

View File

@ -1667,11 +1667,31 @@ public:
return false;
}
bool parseAttribute(Attribute *&result, llvm::SMLoc *loc = nullptr) override {
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name. this
/// captures the location of the attribute in 'loc' if it is non-null.
bool parseAttribute(Attribute *&result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs,
llvm::SMLoc *loc = nullptr) override {
if (loc)
*loc = parser.getToken().getLoc();
result = parser.parseAttribute();
return result == nullptr;
if (!result)
return true;
attrs.push_back(
NamedAttribute(parser.builder.getIdentifier(attrName), result));
return false;
}
/// If a named attribute list is present, parse is into result.
bool parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result,
llvm::SMLoc *loc = nullptr) override {
if (parser.getToken().isNot(Token::l_brace))
return false;
if (loc)
*loc = parser.getToken().getLoc();
return parser.parseAttributeDict(result) == ParseFailure;
}
bool parseOperand(OperandType &result) override {
@ -1685,26 +1705,26 @@ public:
bool parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimeter delimeter = Delimeter::NoDelimeter) override {
Delimiter delimiter = Delimiter::None) override {
auto startLoc = parser.getToken().getLoc();
// Handle delimeters.
switch (delimeter) {
case Delimeter::NoDelimeter:
// Handle delimiters.
switch (delimiter) {
case Delimiter::None:
break;
case Delimeter::OptionalParenDelimeter:
case Delimiter::OptionalParen:
if (parser.getToken().isNot(Token::l_paren))
return false;
LLVM_FALLTHROUGH;
case Delimeter::ParenDelimeter:
case Delimiter::Paren:
if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
return true;
break;
case Delimeter::OptionalSquareDelimeter:
case Delimiter::OptionalSquare:
if (parser.getToken().isNot(Token::l_square))
return false;
LLVM_FALLTHROUGH;
case Delimeter::SquareDelimeter:
case Delimiter::Square:
if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
return true;
break;
@ -1720,18 +1740,18 @@ public:
} while (parser.consumeIf(Token::comma));
}
// Handle delimeters. If we reach here, the optional delimiters were
// Handle delimiters. If we reach here, the optional delimiters were
// present, so we need to parse their closing one.
switch (delimeter) {
case Delimeter::NoDelimeter:
switch (delimiter) {
case Delimiter::None:
break;
case Delimeter::OptionalParenDelimeter:
case Delimeter::ParenDelimeter:
case Delimiter::OptionalParen:
case Delimiter::Paren:
if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
return true;
break;
case Delimeter::OptionalSquareDelimeter:
case Delimeter::SquareDelimeter:
case Delimiter::OptionalSquare:
case Delimiter::Square:
if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
return true;
break;

View File

@ -45,6 +45,9 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32):
// CHECK: %c42_i32_0 = constant 42 : i32
%7 = constant 42 : i32
// CHECK: %c43 = constant 43 {crazy: "foo"} : affineint
%8 = constant 43 {crazy: "foo"} : affineint
return
}

View File

@ -177,20 +177,20 @@ bb42: // CHECK: bb0:
// CHECK: "foo"()
"foo"(){} : ()->()
// CHECK: "foo"(){a: 1, b: -423, c: [true, false], d: 1.600000e+01} : () -> ()
"foo"(){a: 1, b: -423, c: [true, false], d: 16.0 } : () -> ()
// CHECK: "foo"() {a: 1, b: -423, c: [true, false], d: 1.600000e+01} : () -> ()
"foo"() {a: 1, b: -423, c: [true, false], d: 16.0 } : () -> ()
// CHECK: "foo"(){map1: #map{{[0-9]+}}}
"foo"(){map1: #map1} : () -> ()
// CHECK: "foo"() {map1: #map{{[0-9]+}}}
"foo"() {map1: #map1} : () -> ()
// CHECK: "foo"(){map2: #map{{[0-9]+}}}
"foo"(){map2: (d0, d1, d2) -> (d0, d1, d2)} : () -> ()
// CHECK: "foo"() {map2: #map{{[0-9]+}}}
"foo"() {map2: (d0, d1, d2) -> (d0, d1, d2)} : () -> ()
// CHECK: "foo"(){map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]}
"foo"(){map12: [#map1, #map2]} : () -> ()
// CHECK: "foo"() {map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]}
"foo"() {map12: [#map1, #map2]} : () -> ()
// CHECK: "foo"(){cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> ()
"foo"(){if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> ()
// CHECK: "foo"() {cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> ()
"foo"() {if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> ()
return
}