[mlir] Initial support for type constraints in the declarative assembly format

Summary: This revision add support for accepting a few type constraints, e.g. AllTypesMatch, when inferring types for operands and results. This is used to remove the c++ parsers for several additional operations.

Differential Revision: https://reviews.llvm.org/D73735
This commit is contained in:
River Riddle 2020-02-03 21:52:38 -08:00
parent 8413116bf1
commit 7ef37a5f99
12 changed files with 210 additions and 208 deletions

View File

@ -444,7 +444,8 @@ def LLVM_ShuffleVectorOp
// Misc operations.
def LLVM_SelectOp
: LLVM_OneResultOp<"select", [NoSideEffect]>,
: LLVM_OneResultOp<"select",
[NoSideEffect, AllTypesMatch<["trueValue", "falseValue", "res"]>]>,
Arguments<(ins LLVM_Type:$condition, LLVM_Type:$trueValue,
LLVM_Type:$falseValue)>,
LLVM_Builder<
@ -454,8 +455,7 @@ def LLVM_SelectOp
"Value rhs", [{
build(b, result, lhs.getType(), condition, lhs, rhs);
}]>];
let parser = [{ return parseSelectOp(parser, result); }];
let printer = [{ printSelectOp(p, *this); }];
let assemblyFormat = "operands attr-dict `:` type($condition) `,` type($res)";
}
// Terminators.

View File

@ -99,7 +99,8 @@ def SPV_BitCountOp : SPV_BitUnaryOp<"BitCount", []> {
// -----
def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert", [NoSideEffect]> {
def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert",
[NoSideEffect, AllTypesMatch<["base", "insert", "result"]>]> {
let summary = [{
Make a copy of an object, with a modified bit field that comes from
another object.
@ -163,6 +164,12 @@ def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert", [NoSideEffect]> {
let results = (outs
SPV_ScalarOrVectorOf<SPV_Integer>:$result
);
let verifier = [{ return success(); }];
let assemblyFormat = [{
operands attr-dict `:` type($base) `,` type($offset) `,` type($count)
}];
}
// -----

View File

@ -794,7 +794,8 @@ def SPV_SLessThanEqualOp : SPV_LogicalBinaryOp<"SLessThanEqual", SPV_Integer, []
// -----
def SPV_SelectOp : SPV_Op<"Select", [NoSideEffect]> {
def SPV_SelectOp : SPV_Op<"Select",
[NoSideEffect, AllTypesMatch<["true_value", "false_value", "result"]>]> {
let summary = [{
Select between two objects. Before version 1.4, results are only
computed per component.
@ -851,6 +852,10 @@ def SPV_SelectOp : SPV_Op<"Select", [NoSideEffect]> {
let builders = [OpBuilder<[{Builder *builder, OperationState &state,
Value cond, Value trueValue,
Value falseValue}]>];
let assemblyFormat = [{
operands attr-dict `:` type($condition) `,` type($result)
}];
}
// -----

View File

@ -407,10 +407,9 @@ def Vector_InsertOp :
Vector_Op<"insert", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"dest operand and result have same type",
TCresIsSameAsOpBase<0, 1>>]>,
AllTypesMatch<["dest", "res"]>]>,
Arguments<(ins AnyType:$source, AnyVector:$dest, I64ArrayAttr:$position)>,
Results<(outs AnyVector)> {
Results<(outs AnyVector:$res)> {
let summary = "insert operation";
let description = [{
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
@ -425,6 +424,10 @@ def Vector_InsertOp :
f32 into vector<4x8x16xf32>
```
}];
let assemblyFormat = [{
$source `,` $dest $position attr-dict `:` type($source) `into` type($dest)
}];
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value source, " #
"Value dest, ArrayRef<int64_t>">];
@ -497,11 +500,10 @@ def Vector_InsertStridedSliceOp :
Vector_Op<"insert_strided_slice", [NoSideEffect,
PredOpTrait<"operand #0 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"dest operand and result have same type",
TCresIsSameAsOpBase<0, 1>>]>,
AllTypesMatch<["dest", "res"]>]>,
Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
I64ArrayAttr:$strides)>,
Results<(outs AnyVector)> {
Results<(outs AnyVector:$res)> {
let summary = "strided_slice operation";
let description = [{
Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
@ -522,6 +524,11 @@ def Vector_InsertStridedSliceOp :
vector<2x4xf32> into vector<16x4x8xf32>
```
}];
let assemblyFormat = [{
$source `,` $dest attr-dict `:` type($source) `into` type($dest)
}];
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value source, Value dest, " #
"ArrayRef<int64_t> offsets, ArrayRef<int64_t> strides">];

View File

@ -1658,7 +1658,9 @@ class AllMatchSameOperatorTrait<list<string> names, string operator,
string description> :
PredOpTrait<
"all of {" # StrJoin<names>.result # "} have same " # description,
AllMatchSameOperatorPred<names, operator>>;
AllMatchSameOperatorPred<names, operator>> {
list<string> values = names;
}
class AllElementCountsMatch<list<string> names> :
AllMatchSameOperatorTrait<names, ElementCount<"_self">.result,

View File

@ -49,6 +49,9 @@ public:
Kind getKind() const { return kind; }
// Returns the Tablegen definition this operator was constructed from.
const llvm::Record &getDef() const { return *def; }
protected:
// The TableGen definition of this trait.
const llvm::Record *def;

View File

@ -781,40 +781,6 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::SelectOp.
//===----------------------------------------------------------------------===//
static void printSelectOp(OpAsmPrinter &p, SelectOp &op) {
p << op.getOperationName() << ' ' << op.condition() << ", " << op.trueValue()
<< ", " << op.falseValue();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.condition().getType() << ", " << op.trueValue().getType();
}
// <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use
// attribute-dict? `:` type, type
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType condition, trueValue, falseValue;
Type conditionType, argType;
if (parser.parseOperand(condition) || parser.parseComma() ||
parser.parseOperand(trueValue) || parser.parseComma() ||
parser.parseOperand(falseValue) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(conditionType) || parser.parseComma() ||
parser.parseType(argType))
return failure();
if (parser.resolveOperand(condition, conditionType, result.operands) ||
parser.resolveOperand(trueValue, argType, result.operands) ||
parser.resolveOperand(falseValue, argType, result.operands))
return failure();
result.addTypes(argType);
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::BrOp.
//===----------------------------------------------------------------------===//

View File

@ -1065,54 +1065,6 @@ void spirv::BitcastOp::getCanonicalizationPatterns(
results.insert<ConvertChainedBitcast>(context);
}
//===----------------------------------------------------------------------===//
// spv.BitFieldInsert
//===----------------------------------------------------------------------===//
static ParseResult parseBitFieldInsertOp(OpAsmParser &parser,
OperationState &state) {
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
Type baseType;
Type offsetType;
Type countType;
auto loc = parser.getCurrentLocation();
if (parser.parseOperandList(operandInfo, 4) || parser.parseColon() ||
parser.parseType(baseType) || parser.parseComma() ||
parser.parseType(offsetType) || parser.parseComma() ||
parser.parseType(countType) ||
parser.resolveOperands(operandInfo,
{baseType, baseType, offsetType, countType}, loc,
state.operands)) {
return failure();
}
state.addTypes(baseType);
return success();
}
static void print(spirv::BitFieldInsertOp bitFieldInsertOp,
OpAsmPrinter &printer) {
printer << spirv::BitFieldInsertOp::getOperationName() << ' '
<< bitFieldInsertOp.getOperands() << " : "
<< bitFieldInsertOp.base().getType() << ", "
<< bitFieldInsertOp.offset().getType() << ", "
<< bitFieldInsertOp.count().getType();
}
static LogicalResult verify(spirv::BitFieldInsertOp bitFieldOp) {
auto baseType = bitFieldOp.base().getType();
auto insertType = bitFieldOp.insert().getType();
auto resultType = bitFieldOp.getResult().getType();
if ((baseType != insertType) || (baseType != resultType)) {
return bitFieldOp.emitError("expected the same type for the base operand, "
"insert operand, and "
"result, but provided ")
<< baseType << ", " << insertType << " and " << resultType;
}
return success();
}
//===----------------------------------------------------------------------===//
// spv.BranchOp
//===----------------------------------------------------------------------===//
@ -2522,42 +2474,9 @@ void spirv::SelectOp::build(Builder *builder, OperationState &state, Value cond,
build(builder, state, trueValue.getType(), cond, trueValue, falseValue);
}
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
OpAsmParser::OperandType condition;
SmallVector<OpAsmParser::OperandType, 2> operands;
SmallVector<Type, 2> types;
auto loc = parser.getCurrentLocation();
if (parser.parseOperand(condition) || parser.parseComma() ||
parser.parseOperandList(operands, 2) ||
parser.parseColonTypeList(types)) {
return failure();
}
if (types.size() != 2) {
return parser.emitError(
loc, "need exactly two trailing types for select condition and object");
}
if (parser.resolveOperand(condition, types[0], state.operands) ||
parser.resolveOperands(operands, types[1], state.operands)) {
return failure();
}
return parser.addTypesToList(types[1], state.types);
}
static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
printer << spirv::SelectOp::getOperationName() << " " << op.getOperands()
<< " : " << op.condition().getType() << ", " << op.result().getType();
}
static LogicalResult verify(spirv::SelectOp op) {
auto resultTy = op.result().getType();
if (op.true_value().getType() != resultTy) {
return op.emitOpError("result type and true value type must be the same");
}
if (op.false_value().getType() != resultTy) {
return op.emitOpError("result type and false value type must be the same");
}
if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
auto resultVectorTy = resultTy.dyn_cast<VectorType>();
auto resultVectorTy = op.result().getType().dyn_cast<VectorType>();
if (!resultVectorTy) {
return op.emitOpError("result expected to be of vector type when "
"condition is of vector type");

View File

@ -700,31 +700,6 @@ void InsertOp::build(Builder *builder, OperationState &result, Value source,
result.addAttribute(getPositionAttrName(), positionAttr);
}
static void print(OpAsmPrinter &p, InsertOp op) {
p << op.getOperationName() << " " << op.source() << ", " << op.dest()
<< op.position();
p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
p << " : " << op.getSourceType() << " into " << op.getDestVectorType();
}
static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType source, dest;
Type sourceType;
VectorType destType;
Attribute attr;
return failure(parser.parseOperand(source) || parser.parseComma() ||
parser.parseOperand(dest) ||
parser.parseAttribute(attr, InsertOp::getPositionAttrName(),
result.attributes) ||
parser.parseOptionalAttrDict(attrs) ||
parser.parseColonType(sourceType) ||
parser.parseKeywordType("into", destType) ||
parser.resolveOperand(source, sourceType, result.operands) ||
parser.resolveOperand(dest, destType, result.operands) ||
parser.addTypeToList(destType, result.types));
}
static LogicalResult verify(InsertOp op) {
auto positionAttr = op.position().getValue();
if (positionAttr.empty())
@ -793,27 +768,6 @@ void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
result.addAttribute(getStridesAttrName(), stridesAttr);
}
static void print(OpAsmPrinter &p, InsertStridedSliceOp op) {
p << op.getOperationName() << " " << op.source() << ", " << op.dest() << " ";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getSourceVectorType() << " into " << op.getDestVectorType();
}
static ParseResult parseInsertStridedSliceOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType source, dest;
VectorType sourceVectorType, destVectorType;
return failure(
parser.parseOperand(source) || parser.parseComma() ||
parser.parseOperand(dest) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(sourceVectorType) ||
parser.parseKeywordType("into", destVectorType) ||
parser.resolveOperand(source, sourceVectorType, result.operands) ||
parser.resolveOperand(dest, destVectorType, result.operands) ||
parser.addTypeToList(destVectorType, result.types));
}
// TODO(ntv) Should be moved to Tablegen Confined attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,

View File

@ -227,7 +227,7 @@ func @bit_field_insert_vec(%base: vector<3xi32>, %insert: vector<3xi32>, %offset
// -----
func @bit_field_insert_invalid_insert_type(%base: vector<3xi32>, %insert: vector<2xi32>, %offset: i32, %count: i16) -> vector<3xi32> {
// expected-error @+1 {{expected the same type for the base operand, insert operand, and result, but provided 'vector<3xi32>', 'vector<2xi32>' and 'vector<3xi32>'}}
// expected-error @+1 {{all of {base, insert, result} have same type}}
%0 = "spv.BitFieldInsert" (%base, %insert, %offset, %count) : (vector<3xi32>, vector<2xi32>, i32, i16) -> vector<3xi32>
spv.ReturnValue %0 : vector<3xi32>
}
@ -856,7 +856,7 @@ func @select_op_vec_condn_vec(%arg0: vector<3xi1>) -> () {
func @select_op(%arg0: i1) -> () {
%0 = spv.constant 2 : i32
%1 = spv.constant 3 : i32
// expected-error @+1 {{need exactly two trailing types for select condition and object}}
// expected-error @+2 {{expected ','}}
%2 = spv.Select %arg0, %0, %1 : i1
return
}
@ -886,7 +886,7 @@ func @select_op(%arg1: vector<4xi1>) -> () {
func @select_op(%arg1: vector<4xi1>) -> () {
%0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
%1 = spv.constant dense<[5, 6, 7]> : vector<3xi32>
// expected-error @+1 {{op result type and true value type must be the same}}
// expected-error @+1 {{all of {true_value, false_value, result} have same type}}
%2 = "spv.Select"(%arg1, %0, %1) : (vector<4xi1>, vector<3xf32>, vector<3xi32>) -> vector<3xi32>
return
}
@ -896,7 +896,7 @@ func @select_op(%arg1: vector<4xi1>) -> () {
func @select_op(%arg1: vector<4xi1>) -> () {
%0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
%1 = spv.constant dense<[5, 6, 7]> : vector<3xi32>
// expected-error @+1 {{op result type and false value type must be the same}}
// expected-error @+1 {{all of {true_value, false_value, result} have same type}}
%2 = "spv.Select"(%arg1, %1, %0) : (vector<4xi1>, vector<3xi32>, vector<3xf32>) -> vector<3xi32>
return
}

View File

@ -7,7 +7,8 @@ include "mlir/IR/OpBase.td"
def TestDialect : Dialect {
let name = "test";
}
class TestFormat_Op<string name, string fmt> : Op<TestDialect, name> {
class TestFormat_Op<string name, string fmt, list<OpTrait> traits = []>
: Op<TestDialect, name, traits> {
let assemblyFormat = fmt;
}
@ -234,3 +235,24 @@ def ZCoverageValidC : TestFormat_Op<"variable_valid_c", [{
operands functional-type(operands, results) attr-dict
}]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
// Check that we can infer type equalities from certain traits.
def ZCoverageValidD : TestFormat_Op<"variable_valid_d", [{
operands type($result) attr-dict
}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>,
Results<(outs AnyMemRef:$result)>;
def ZCoverageValidE : TestFormat_Op<"variable_valid_e", [{
$operand type($operand) attr-dict
}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>,
Results<(outs AnyMemRef:$result)>;
def ZCoverageValidF : TestFormat_Op<"variable_valid_f", [{
operands type($other) attr-dict
}], [SameTypeOperands]>, Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>;
def ZCoverageValidG : TestFormat_Op<"variable_valid_g", [{
operands type($other) attr-dict
}], [AllTypesMatch<["operand", "other"]>]>,
Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>;
def ZCoverageValidH : TestFormat_Op<"variable_valid_h", [{
operands type($result) attr-dict
}], [AllTypesMatch<["operand", "result"]>]>,
Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;

View File

@ -202,10 +202,32 @@ bool LiteralElement::isValidLiteral(StringRef value) {
namespace {
struct OperationFormat {
/// This class represents a specific resolver for an operand or result type.
class TypeResolution {
public:
TypeResolution() = default;
/// Get the index into the buildable types for this type, or None.
Optional<int> getBuilderIdx() const { return builderIdx; }
void setBuilderIdx(int idx) { builderIdx = idx; }
/// Get the variable this type is resolved to, or None.
Optional<StringRef> getVariable() const { return variableName; }
void setVariable(StringRef variable) { variableName = variable; }
private:
/// If the type is resolved with a buildable type, this is the index into
/// 'buildableTypes' in the parent format.
Optional<int> builderIdx;
/// If the type is resolved based upon another operand or result, this is
/// the name of the variable that this type is resolved to.
Optional<StringRef> variableName;
};
OperationFormat(const Operator &op)
: allOperandTypes(false), allResultTypes(false) {
buildableOperandTypes.resize(op.getNumOperands(), llvm::None);
buildableResultTypes.resize(op.getNumResults(), llvm::None);
operandTypes.resize(op.getNumOperands(), TypeResolution());
resultTypes.resize(op.getNumResults(), TypeResolution());
}
/// Generate the operation parser from this format.
@ -228,7 +250,7 @@ struct OperationFormat {
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
/// The index of the buildable type, if valid, for every operand and result.
std::vector<Optional<int>> buildableOperandTypes, buildableResultTypes;
std::vector<TypeResolution> operandTypes, resultTypes;
};
} // end anonymous namespace
@ -398,8 +420,10 @@ void OperationFormat::genParserTypeResolution(Operator &op,
} else {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
body << " result.addTypes(";
if (Optional<int> val = buildableResultTypes[i])
if (Optional<int> val = resultTypes[i].getBuilderIdx())
body << "odsBuildableType" << *val;
else if (Optional<StringRef> var = resultTypes[i].getVariable())
body << *var << "Types";
else
body << op.getResultName(i) << "Types";
body << ");\n";
@ -450,8 +474,10 @@ void OperationFormat::genParserTypeResolution(Operator &op,
if (op.getNumOperands() > 1) {
body << "llvm::concat<const Type>(";
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
if (Optional<int> val = buildableOperandTypes[i])
if (Optional<int> val = operandTypes[i].getBuilderIdx())
body << "ArrayRef<Type>(odsBuildableType" << *val << ")";
else if (Optional<StringRef> var = operandTypes[i].getVariable())
body << *var << "Types";
else
body << op.getOperand(i).name << "Types";
});
@ -470,8 +496,10 @@ void OperationFormat::genParserTypeResolution(Operator &op,
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
NamedTypeConstraint &operand = op.getOperand(i);
body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
if (Optional<int> val = buildableOperandTypes[i])
if (Optional<int> val = operandTypes[i].getBuilderIdx())
body << "odsBuildableType" << *val << ", ";
else if (Optional<StringRef> var = operandTypes[i].getVariable())
body << *var << "Types, " << operand.name << "OperandsLoc, ";
else
body << operand.name << "Types, " << operand.name << "OperandsLoc, ";
body << "result.operands))\n return failure();\n";
@ -803,6 +831,13 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
// FormatParser
//===----------------------------------------------------------------------===//
/// Function to find an element within the given range that has the same name as
/// 'name'.
template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
return it != range.end() ? &*it : nullptr;
}
namespace {
/// This class implements a parser for an instance of an operation assembly
/// format.
@ -817,6 +852,18 @@ public:
LogicalResult parse();
private:
/// Given the values of an `AllTypesMatch` trait, check for inferrable type
/// resolution.
void handleAllTypesMatchConstraint(
ArrayRef<StringRef> values,
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver);
/// Check for inferrable type resolution given all operands, and or results,
/// have the same type. If 'includeResults' is true, the results also have the
/// same type as all of the operands.
void handleSameTypesConstraint(
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
bool includeResults);
/// Parse a specific element.
LogicalResult parseElement(std::unique_ptr<Element> &element,
bool isTopLevel);
@ -870,8 +917,8 @@ private:
OperationFormat &fmt;
Operator &op;
// The following are various bits of format state used for verification during
// parsing.
// The following are various bits of format state used for verification
// during parsing.
bool hasAllOperands = false, hasAttrDict = false;
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
@ -894,6 +941,19 @@ LogicalResult FormatParser::parse() {
if (!hasAttrDict)
return emitError(loc, "format missing 'attr-dict' directive");
// Check for any type traits that we can use for inferring types.
llvm::StringMap<const NamedTypeConstraint *> variableTyResolver;
for (const OpTrait &trait : op.getTraits()) {
const llvm::Record &def = trait.getDef();
if (def.isSubClassOf("AllTypesMatch"))
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
variableTyResolver);
else if (def.getName() == "SameTypeOperands")
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
else if (def.getName() == "SameOperandsAndResultType")
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
}
// Check that all of the result types can be inferred.
auto &buildableTypes = fmt.buildableTypes;
if (!fmt.allResultTypes) {
@ -901,6 +961,13 @@ LogicalResult FormatParser::parse() {
if (seenResultTypes.test(i))
continue;
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
if (varResolverIt != variableTyResolver.end()) {
fmt.resultTypes[i].setVariable(varResolverIt->second->name);
continue;
}
// If the result is not variadic, allow for the case where the type has a
// builder that we can use.
NamedTypeConstraint &result = op.getResult(i);
@ -911,7 +978,7 @@ LogicalResult FormatParser::parse() {
}
// Note in the format that this result uses the custom builder.
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.buildableResultTypes[i] = it.first->second;
fmt.resultTypes[i].setBuilderIdx(it.first->second);
}
}
@ -927,21 +994,78 @@ LogicalResult FormatParser::parse() {
}
// Check that the operand type is in the format, or that it can be inferred.
if (!fmt.allOperandTypes && !seenOperandTypes.test(i)) {
// Similarly to results, allow a custom builder for resolving the type if
// we aren't using the 'operands' directive.
Optional<StringRef> builder = operand.constraint.getBuilderCall();
if (!builder || (hasAllOperands && operand.isVariadic())) {
return emitError(loc, "format missing instance of operand #" +
Twine(i) + "('" + operand.name + "') type");
}
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.buildableOperandTypes[i] = it.first->second;
if (fmt.allOperandTypes || seenOperandTypes.test(i))
continue;
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
if (varResolverIt != variableTyResolver.end()) {
fmt.operandTypes[i].setVariable(varResolverIt->second->name);
continue;
}
// Similarly to results, allow a custom builder for resolving the type if
// we aren't using the 'operands' directive.
Optional<StringRef> builder = operand.constraint.getBuilderCall();
if (!builder || (hasAllOperands && operand.isVariadic())) {
return emitError(loc, "format missing instance of operand #" + Twine(i) +
"('" + operand.name + "') type");
}
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.operandTypes[i].setBuilderIdx(it.first->second);
}
return success();
}
void FormatParser::handleAllTypesMatchConstraint(
ArrayRef<StringRef> values,
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver) {
for (unsigned i = 0, e = values.size(); i != e; ++i) {
// Check to see if this value matches a resolved operand or result type.
const NamedTypeConstraint *arg = nullptr;
if ((arg = findArg(op.getOperands(), values[i]))) {
if (!seenOperandTypes.test(arg - op.operand_begin()))
continue;
} else if ((arg = findArg(op.getResults(), values[i]))) {
if (!seenResultTypes.test(arg - op.result_begin()))
continue;
} else {
continue;
}
// Mark this value as the type resolver for the other variables.
for (unsigned j = 0; j != i; ++j)
variableTyResolver[values[j]] = arg;
for (unsigned j = i + 1; j != e; ++j)
variableTyResolver[values[j]] = arg;
}
}
void FormatParser::handleSameTypesConstraint(
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
bool includeResults) {
const NamedTypeConstraint *resolver = nullptr;
int resolvedIt = -1;
// Check to see if there is an operand or result to use for the resolution.
if ((resolvedIt = seenOperandTypes.find_first()) != -1)
resolver = &op.getOperand(resolvedIt);
else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
resolver = &op.getResult(resolvedIt);
else
return;
// Set the resolvers for each operand and result.
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
variableTyResolver[op.getOperand(i).name] = resolver;
if (includeResults) {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
variableTyResolver[op.getResultName(i)] = resolver;
}
}
LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
bool isTopLevel) {
// Directives.
@ -965,23 +1089,16 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
StringRef name = varTok.getSpelling().drop_front();
llvm::SMLoc loc = varTok.getLoc();
// Functor used to find an element within the given range that has the same
// name as 'name'.
auto findArg = [&](auto &&range) {
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
return it != range.end() ? &*it : nullptr;
};
// Check that the parsed argument is something actually registered on the op.
/// Attributes
if (const NamedAttribute *attr = findArg(op.getAttributes())) {
if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
if (isTopLevel && !seenAttrs.insert(attr).second)
return emitError(loc, "attribute '" + name + "' is already bound");
element = std::make_unique<AttributeVariable>(attr);
return success();
}
/// Operands
if (const NamedTypeConstraint *operand = findArg(op.getOperands())) {
if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
if (isTopLevel) {
if (hasAllOperands || !seenOperands.insert(operand).second)
return emitError(loc, "operand '" + name + "' is already bound");
@ -990,7 +1107,7 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
return success();
}
/// Results.
if (const NamedTypeConstraint *result = findArg(op.getResults())) {
if (const auto *result = findArg(op.getResults(), name)) {
if (isTopLevel)
return emitError(loc, "results can not be used at the top level");
element = std::make_unique<ResultVariable>(result);