[mlir][DeclarativeParser] Add support for the TypesMatchWith trait.

This allows for injecting type constraints that are not direct 1-1 mappings, for example when one type is equal to the element type of another. This allows for moving over several more parsers to the declarative form.

Differential Revision: https://reviews.llvm.org/D74648
This commit is contained in:
River Riddle 2020-02-21 13:19:03 -08:00
parent 393f4e8ac2
commit 26222db01b
8 changed files with 162 additions and 343 deletions

View File

@ -357,6 +357,8 @@ def CallIndirectOp : Std_Op<"call_indirect", [
let verifier = ?;
let hasCanonicalizer = 1;
let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)";
}
def CeilFOp : FloatUnaryOp<"ceilf"> {
@ -490,6 +492,8 @@ def CmpIOp : Std_Op<"cmpi",
let verifier = [{ return success(); }];
let hasFolder = 1;
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
}
def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
@ -761,6 +765,10 @@ def ExtractElementOp : Std_Op<"extract_element",
}];
let hasFolder = 1;
let assemblyFormat = [{
$aggregate `[` $indices `]` attr-dict `:` type($aggregate)
}];
}
def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
@ -853,6 +861,8 @@ def LoadOp : Std_Op<"load",
}];
let hasFolder = 1;
let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
}
def LogOp : FloatUnaryOp<"log"> {
@ -1090,6 +1100,10 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
}];
let hasFolder = 1;
let assemblyFormat = [{
$condition `,` $true_value `,` $false_value attr-dict `:` type($result)
}];
}
def SignExtendIOp : Std_Op<"sexti",
@ -1222,6 +1236,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
[{ build(builder, result, aggregateType, element); }]>];
let hasFolder = 1;
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
}
def StoreOp : Std_Op<"store",
@ -1264,6 +1280,10 @@ def StoreOp : Std_Op<"store",
}];
let hasFolder = 1;
let assemblyFormat = [{
$value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
}];
}
def SubFOp : FloatArithmeticOp<"subf"> {
@ -1517,11 +1537,12 @@ def TensorLoadOp : Std_Op<"tensor_load",
result.addTypes(resultType);
}]>];
let extraClassDeclaration = [{
/// The result of a tensor_load is always a tensor.
TensorType getType() { return getResult().getType().cast<TensorType>(); }
}];
let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
def TensorStoreOp : Std_Op<"tensor_store",
@ -1545,6 +1566,8 @@ def TensorStoreOp : Std_Op<"tensor_store",
let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref);
// TensorStoreOp is fully verified by traits.
let verifier = ?;
let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
}
def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {

View File

@ -363,6 +363,10 @@ def Vector_ExtractElementOp :
return vector().getType().cast<VectorType>();
}
}];
let assemblyFormat = [{
$vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
}];
}
def Vector_ExtractOp :
@ -512,6 +516,11 @@ def Vector_InsertElementOp :
return dest().getType().cast<VectorType>();
}
}];
let assemblyFormat = [{
$source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
type($result)
}];
}
def Vector_InsertOp :

View File

@ -496,6 +496,12 @@ public:
return failure();
return success();
}
template <typename Operands>
ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc,
SmallVectorImpl<Value> &result) {
return resolveOperands(std::forward<Operands>(operands),
ArrayRef<Type>(type), loc, result);
}
template <typename Operands, typename Types>
ParseResult resolveOperands(Operands &&operands, Types &&types,
llvm::SMLoc loc, SmallVectorImpl<Value> &result) {

View File

@ -294,6 +294,11 @@ public:
void addTypes(ArrayRef<Type> newTypes) {
types.append(newTypes.begin(), newTypes.end());
}
template <typename RangeT>
std::enable_if_t<!std::is_convertible<RangeT, ArrayRef<Type>>::value>
addTypes(RangeT &&newTypes) {
types.append(newTypes.begin(), newTypes.end());
}
/// Add an attribute with the specified name.
void addAttribute(StringRef name, Attribute attr) {

View File

@ -505,29 +505,6 @@ struct SimplifyIndirectCallWithKnownCallee
};
} // end anonymous namespace.
static ParseResult parseCallIndirectOp(OpAsmParser &parser,
OperationState &result) {
FunctionType calleeType;
OpAsmParser::OperandType callee;
llvm::SMLoc operandsLoc;
SmallVector<OpAsmParser::OperandType, 4> operands;
return failure(
parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) ||
parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(calleeType) ||
parser.resolveOperand(callee, calleeType, result.operands) ||
parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
result.operands) ||
parser.addTypesToList(calleeType.getResults(), result.types));
}
static void print(OpAsmPrinter &p, CallIndirectOp op) {
p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
p << " : " << op.getCallee().getType();
}
void CallIndirectOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SimplifyIndirectCallWithKnownCallee>(context);
@ -570,55 +547,6 @@ static void buildCmpIOp(Builder *build, OperationState &result,
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
}
static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
SmallVector<NamedAttribute, 4> attrs;
Attribute predicateNameAttr;
Type type;
if (parser.parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(),
attrs) ||
parser.parseComma() || parser.parseOperandList(ops, 2) ||
parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
parser.resolveOperands(ops, type, result.operands))
return failure();
if (!predicateNameAttr.isa<StringAttr>())
return parser.emitError(parser.getNameLoc(),
"expected string comparison predicate attribute");
// Rewrite string attribute to an enum value.
StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName);
if (!predicate.hasValue())
return parser.emitError(parser.getNameLoc())
<< "unknown comparison predicate \"" << predicateName << "\"";
auto builder = parser.getBuilder();
Type i1Type = getCheckedI1SameShape(type);
if (!i1Type)
return parser.emitError(parser.getNameLoc(),
"expected type with valid i1 shape");
attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate));
result.attributes = attrs;
result.addTypes({i1Type});
return success();
}
static void print(OpAsmPrinter &p, CmpIOp op) {
p << "cmpi ";
Builder b(op.getContext());
auto predicateValue =
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
<< '"' << ", " << op.lhs() << ", " << op.rhs();
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
p << " : " << op.lhs().getType();
}
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
// comparison predicates.
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
@ -1486,30 +1414,6 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
// ExtractElementOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, ExtractElementOp op) {
p << "extract_element " << op.getAggregate() << '[' << op.getIndices();
p << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getAggregate().getType();
}
static ParseResult parseExtractElementOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType aggregateInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
ShapedType type;
auto indexTy = parser.getBuilder().getIndexType();
return failure(
parser.parseOperand(aggregateInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(aggregateInfo, type, result.operands) ||
parser.resolveOperands(indexInfo, indexTy, result.operands) ||
parser.addTypeToList(type.getElementType(), result.types));
}
static LogicalResult verify(ExtractElementOp op) {
// Verify the # indices match if we have a ranked type.
auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
@ -1577,28 +1481,6 @@ OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
// LoadOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, LoadOp op) {
p << "load " << op.getMemRef() << '[' << op.getIndices() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
}
static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType type;
auto indexTy = parser.getBuilder().getIndexType();
return failure(
parser.parseOperand(memrefInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(memrefInfo, type, result.operands) ||
parser.resolveOperands(indexInfo, indexTy, result.operands) ||
parser.addTypeToList(type.getElementType(), result.types));
}
static LogicalResult verify(LoadOp op) {
if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
return op.emitOpError("incorrect number of indices for load");
@ -1902,31 +1784,6 @@ bool SIToFPOp::areCastCompatible(Type a, Type b) {
// SelectOp
//===----------------------------------------------------------------------===//
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<NamedAttribute, 4> attrs;
Type type;
if (parser.parseOperandList(ops, 3) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type))
return failure();
auto i1Type = getCheckedI1SameShape(type);
if (!i1Type)
return parser.emitError(parser.getNameLoc(),
"expected type with valid i1 shape");
std::array<Type, 3> types = {i1Type, type, type};
return failure(parser.resolveOperands(ops, types, parser.getNameLoc(),
result.operands) ||
parser.addTypeToList(type, result.types));
}
static void print(OpAsmPrinter &p, SelectOp op) {
p << "select " << op.getOperands() << " : " << op.getTrueValue().getType();
p.printOptionalAttrDict(op.getAttrs());
}
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
auto condition = getCondition();
@ -1968,25 +1825,6 @@ static LogicalResult verify(SignExtendIOp op) {
// SplatOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, SplatOp op) {
p << "splat " << op.getOperand();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getType();
}
static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType splatValueInfo;
ShapedType shapedType;
return failure(parser.parseOperand(splatValueInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(shapedType) ||
parser.resolveOperand(splatValueInfo,
shapedType.getElementType(),
result.operands) ||
parser.addTypeToList(shapedType, result.types));
}
static LogicalResult verify(SplatOp op) {
// TODO: we could replace this by a trait.
if (op.getOperand().getType() !=
@ -2017,32 +1855,6 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
// StoreOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, StoreOp op) {
p << "store " << op.getValueToStore();
p << ", " << op.getMemRef() << '[' << op.getIndices() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getMemRefType();
}
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType memrefType;
auto indexTy = parser.getBuilder().getIndexType();
return failure(
parser.parseOperand(storeValueInfo) || parser.parseComma() ||
parser.parseOperand(memrefInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(memrefType) ||
parser.resolveOperand(storeValueInfo, memrefType.getElementType(),
result.operands) ||
parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
parser.resolveOperands(indexInfo, indexTy, result.operands));
}
static LogicalResult verify(StoreOp op) {
if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
return op.emitOpError("store index operand count not equal to memref rank");
@ -2156,51 +1968,6 @@ static Type getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, TensorLoadOp op) {
p << "tensor_load " << op.getOperand();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getOperand().getType();
}
static ParseResult parseTensorLoadOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType op;
Type type;
return failure(
parser.parseOperand(op) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(op, type, result.operands) ||
parser.addTypeToList(getTensorTypeFromMemRefType(type), result.types));
}
//===----------------------------------------------------------------------===//
// TensorStoreOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, TensorStoreOp op) {
p << "tensor_store " << op.tensor() << ", " << op.memref();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.memref().getType();
}
static ParseResult parseTensorStoreOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type type;
llvm::SMLoc loc = parser.getCurrentLocation();
return failure(
parser.parseOperandList(ops, /*requiredOperandCount=*/2) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperands(ops, {getTensorTypeFromMemRefType(type), type},
loc, result.operands));
}
//===----------------------------------------------------------------------===//
// TruncateIOp
//===----------------------------------------------------------------------===//

View File

@ -412,31 +412,6 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
// ExtractElementOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
p << op.getOperationName() << " " << op.vector() << "[" << op.position()
<< " : " << op.position().getType() << "]";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.vector().getType();
}
static ParseResult parseExtractElementOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType vector, position;
Type positionType;
VectorType vectorType;
if (parser.parseOperand(vector) || parser.parseLSquare() ||
parser.parseOperand(position) || parser.parseColonType(positionType) ||
parser.parseRSquare() ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(vectorType))
return failure();
Type resultType = vectorType.getElementType();
return failure(
parser.resolveOperand(vector, vectorType, result.operands) ||
parser.resolveOperand(position, positionType, result.operands) ||
parser.addTypeToList(resultType, result.types));
}
static LogicalResult verify(vector::ExtractElementOp op) {
VectorType vectorType = op.getVectorType();
if (vectorType.getRank() != 1)
@ -715,33 +690,6 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
// InsertElementOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, InsertElementOp op) {
p << op.getOperationName() << " " << op.source() << ", " << op.dest() << "["
<< op.position() << " : " << op.position().getType() << "]";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.dest().getType();
}
static ParseResult parseInsertElementOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType source, dest, position;
Type positionType;
VectorType destType;
if (parser.parseOperand(source) || parser.parseComma() ||
parser.parseOperand(dest) || parser.parseLSquare() ||
parser.parseOperand(position) || parser.parseColonType(positionType) ||
parser.parseRSquare() ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(destType))
return failure();
Type sourceType = destType.getElementType();
return failure(
parser.resolveOperand(source, sourceType, result.operands) ||
parser.resolveOperand(dest, destType, result.operands) ||
parser.resolveOperand(position, positionType, result.operands) ||
parser.addTypeToList(destType, result.types));
}
static LogicalResult verify(InsertElementOp op) {
auto dstVectorType = op.getDestVectorType();
if (dstVectorType.getRank() != 1)

View File

@ -226,7 +226,7 @@ func @func_with_ops(i32, i32) {
// Integer comparisons are not recognized for float types.
func @func_with_ops(f32, f32) {
^bb0(%a : f32, %b : f32):
%r = cmpi "eq", %a, %b : f32 // expected-error {{operand #0 must be integer-like}}
%r = cmpi "eq", %a, %b : f32 // expected-error {{'lhs' must be integer-like, but got 'f32'}}
}
// -----
@ -298,13 +298,13 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
// -----
func @invalid_select_shape(%cond : i1, %idx : () -> ()) {
// expected-error@+1 {{expected type with valid i1 shape}}
// expected-error@+1 {{'result' must be integer-like or floating-point-like, but got '() -> ()'}}
%sel = select %cond, %idx, %idx : () -> ()
// -----
func @invalid_cmp_shape(%idx : () -> ()) {
// expected-error@+1 {{expected type with valid i1 shape}}
// expected-error@+1 {{'lhs' must be integer-like, but got '() -> ()'}}
%cmp = cmpi "eq", %idx, %idx : () -> ()
// -----
@ -340,7 +340,7 @@ func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
// -----
func @invalid_cmp_attr(%idx : i32) {
// expected-error@+1 {{expected string comparison predicate attribute}}
// expected-error@+1 {{invalid kind of attribute specified}}
%cmp = cmpi i1, %idx, %idx : i32
// -----

View File

@ -219,16 +219,26 @@ struct OperationFormat {
void setBuilderIdx(int idx) { builderIdx = idx; }
/// Get the variable this type is resolved to, or None.
Optional<StringRef> getVariable() const { return variableName; }
void setVariable(StringRef variable) { variableName = variable; }
const NamedTypeConstraint *getVariable() const { return variable; }
Optional<StringRef> getVarTransformer() const {
return variableTransformer;
}
void setVariable(const NamedTypeConstraint *var,
Optional<StringRef> transformer) {
variable = var;
variableTransformer = transformer;
}
private:
/// If the type is resolved with a buildable type, this is the index into
/// 'buildableTypes' in the parent format.
Optional<int> builderIdx;
/// If the type is resolved based upon another operand or result, this is
/// the name of the variable that this type is resolved to.
Optional<StringRef> variableName;
/// the variable that this type is resolved to.
const NamedTypeConstraint *variable;
/// If the type is resolved based upon another operand or result, this is
/// a transformer to apply to the variable when resolving.
Optional<StringRef> variableTransformer;
};
OperationFormat(const Operator &op)
@ -487,6 +497,34 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
void OperationFormat::genParserTypeResolution(Operator &op,
OpMethodBody &body) {
// If any of type resolutions use transformed variables, make sure that the
// types of those variables are resolved.
SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
FmtContext verifierFCtx;
for (TypeResolution &resolver :
llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
Optional<StringRef> transformer = resolver.getVarTransformer();
if (!transformer)
continue;
// Ensure that we don't verify the same variables twice.
const NamedTypeConstraint *variable = resolver.getVariable();
if (!verifiedVariables.insert(variable).second)
continue;
auto constraint = variable->constraint;
body << " for (Type type : " << variable->name << "Types) {\n"
<< " (void)type;\n"
<< " if (!("
<< tgfmt(constraint.getConditionTemplate(),
&verifierFCtx.withSelf("type"))
<< ")) {\n"
<< formatv(" return parser.emitError(parser.getNameLoc()) << "
"\"'{0}' must be {1}, but got \" << type;\n",
variable->name, constraint.getDescription())
<< " }\n"
<< " }\n";
}
// Initialize the set of buildable types.
if (!buildableTypes.empty()) {
body << " Builder &builder = parser.getBuilder();\n";
@ -498,18 +536,27 @@ void OperationFormat::genParserTypeResolution(Operator &op,
<< tgfmt(it.first, &typeBuilderCtx) << ";\n";
}
// Emit the code necessary for a type resolver.
auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
if (Optional<int> val = resolver.getBuilderIdx()) {
body << "odsBuildableType" << *val;
} else if (const NamedTypeConstraint *var = resolver.getVariable()) {
if (Optional<StringRef> tform = resolver.getVarTransformer())
body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
else
body << var->name << "Types";
} else {
body << curVar << "Types";
}
};
// Resolve each of the result types.
if (allResultTypes) {
body << " result.addTypes(allResultTypes);\n";
} else {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
body << " result.addTypes(";
if (Optional<int> val = resultTypes[i].getBuilderIdx())
body << "odsBuildableType" << *val;
else if (Optional<StringRef> var = resultTypes[i].getVariable())
body << *var << "Types";
else
body << op.getResultName(i) << "Types";
emitTypeResolver(resultTypes[i], op.getResultName(i));
body << ");\n";
}
}
@ -552,25 +599,19 @@ void OperationFormat::genParserTypeResolution(Operator &op,
if (hasAllOperands) {
body << " if (parser.resolveOperands(allOperands, ";
auto emitOperandType = [&](int idx) {
if (Optional<int> val = operandTypes[idx].getBuilderIdx())
body << "ArrayRef<Type>(odsBuildableType" << *val << ")";
else if (Optional<StringRef> var = operandTypes[idx].getVariable())
body << *var << "Types";
else
body << op.getOperand(idx).name << "Types";
};
// Group all of the operand types together to perform the resolution all at
// once. Use llvm::concat to perform the merge. llvm::concat does not allow
// the case of a single range, so guard it here.
if (op.getNumOperands() > 1) {
body << "llvm::concat<const Type>(";
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body,
emitOperandType);
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
body << "ArrayRef<Type>(";
emitTypeResolver(operandTypes[i], op.getOperand(i).name);
body << ")";
});
body << ")";
} else {
emitOperandType(/*idx=*/0);
emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
}
body << ", allOperandLoc, result.operands))\n"
@ -583,13 +624,12 @@ void OperationFormat::genParserTypeResolution(Operator &op,
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
NamedTypeConstraint &operand = op.getOperand(i);
body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
if (Optional<int> val = operandTypes[i].getBuilderIdx())
body << "odsBuildableType" << *val << ", ";
else if (Optional<StringRef> var = operandTypes[i].getVariable())
body << *var << "Types, " << operand.name << "OperandsLoc, ";
else
body << operand.name << "Types, " << operand.name << "OperandsLoc, ";
body << "result.operands))\n return failure();\n";
emitTypeResolver(operandTypes[i], operand.name);
// If this isn't a buildable type, verify the sizes match by adding the loc.
if (!operandTypes[i].getBuilderIdx())
body << ", " << operand.name << "OperandsLoc";
body << ", result.operands))\n return failure();\n";
}
}
@ -954,18 +994,30 @@ public:
LogicalResult parse();
private:
/// This struct represents a type resolution instance. It includes a specific
/// type as well as an optional transformer to apply to that type in order to
/// properly resolve the type of a variable.
struct TypeResolutionInstance {
const NamedTypeConstraint *type;
Optional<StringRef> transformer;
};
/// Given the values of an `AllTypesMatch` trait, check for inferrable type
/// resolution.
void handleAllTypesMatchConstraint(
ArrayRef<StringRef> values,
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver);
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
/// Check for inferrable type resolution given all operands, and or results,
/// have the same type. If 'includeResults' is true, the results also have the
/// same type as all of the operands.
void handleSameTypesConstraint(
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults);
/// Returns an argument with the given name that has been seen within the
/// format.
const NamedTypeConstraint *findSeenArg(StringRef name);
/// Parse a specific element.
LogicalResult parseElement(std::unique_ptr<Element> &element,
bool isTopLevel);
@ -1044,16 +1096,21 @@ LogicalResult FormatParser::parse() {
return emitError(loc, "format missing 'attr-dict' directive");
// Check for any type traits that we can use for inferring types.
llvm::StringMap<const NamedTypeConstraint *> variableTyResolver;
llvm::StringMap<TypeResolutionInstance> variableTyResolver;
for (const OpTrait &trait : op.getTraits()) {
const llvm::Record &def = trait.getDef();
if (def.isSubClassOf("AllTypesMatch"))
if (def.isSubClassOf("AllTypesMatch")) {
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
variableTyResolver);
else if (def.getName() == "SameTypeOperands")
} else if (def.getName() == "SameTypeOperands") {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
else if (def.getName() == "SameOperandsAndResultType")
} else if (def.getName() == "SameOperandsAndResultType") {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs")))
variableTyResolver[def.getValueAsString("rhs")] = {
lhsArg, def.getValueAsString("transformer")};
}
}
// Check that all of the result types can be inferred.
@ -1066,7 +1123,8 @@ LogicalResult FormatParser::parse() {
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
if (varResolverIt != variableTyResolver.end()) {
fmt.resultTypes[i].setVariable(varResolverIt->second->name);
fmt.resultTypes[i].setVariable(varResolverIt->second.type,
varResolverIt->second.transformer);
continue;
}
@ -1102,7 +1160,8 @@ LogicalResult FormatParser::parse() {
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
if (varResolverIt != variableTyResolver.end()) {
fmt.operandTypes[i].setVariable(varResolverIt->second->name);
fmt.operandTypes[i].setVariable(varResolverIt->second.type,
varResolverIt->second.transformer);
continue;
}
@ -1121,30 +1180,23 @@ LogicalResult FormatParser::parse() {
void FormatParser::handleAllTypesMatchConstraint(
ArrayRef<StringRef> values,
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver) {
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
for (unsigned i = 0, e = values.size(); i != e; ++i) {
// Check to see if this value matches a resolved operand or result type.
const NamedTypeConstraint *arg = nullptr;
if ((arg = findArg(op.getOperands(), values[i]))) {
if (!seenOperandTypes.test(arg - op.operand_begin()))
continue;
} else if ((arg = findArg(op.getResults(), values[i]))) {
if (!seenResultTypes.test(arg - op.result_begin()))
continue;
} else {
const NamedTypeConstraint *arg = findSeenArg(values[i]);
if (!arg)
continue;
}
// Mark this value as the type resolver for the other variables.
for (unsigned j = 0; j != i; ++j)
variableTyResolver[values[j]] = arg;
variableTyResolver[values[j]] = {arg, llvm::None};
for (unsigned j = i + 1; j != e; ++j)
variableTyResolver[values[j]] = arg;
variableTyResolver[values[j]] = {arg, llvm::None};
}
}
void FormatParser::handleSameTypesConstraint(
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults) {
const NamedTypeConstraint *resolver = nullptr;
int resolvedIt = -1;
@ -1160,14 +1212,22 @@ void FormatParser::handleSameTypesConstraint(
// Set the resolvers for each operand and result.
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
variableTyResolver[op.getOperand(i).name] = resolver;
variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
if (includeResults) {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
variableTyResolver[op.getResultName(i)] = resolver;
variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
}
}
const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) {
if (auto *arg = findArg(op.getOperands(), name))
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
if (auto *arg = findArg(op.getResults(), name))
return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
return nullptr;
}
LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
bool isTopLevel) {
// Directives.
@ -1191,7 +1251,8 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
StringRef name = varTok.getSpelling().drop_front();
llvm::SMLoc loc = varTok.getLoc();
// Check that the parsed argument is something actually registered on the op.
// Check that the parsed argument is something actually registered on the
// op.
/// Attributes
if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
if (isTopLevel && !seenAttrs.insert(attr).second)