forked from OSchip/llvm-project
[mlir] Initial support for type constraints in the declarative assembly format
Summary: This revision add support for accepting a few type constraints, e.g. AllTypesMatch, when inferring types for operands and results. This is used to remove the c++ parsers for several additional operations. Differential Revision: https://reviews.llvm.org/D73735
This commit is contained in:
parent
8413116bf1
commit
7ef37a5f99
|
@ -444,7 +444,8 @@ def LLVM_ShuffleVectorOp
|
|||
|
||||
// Misc operations.
|
||||
def LLVM_SelectOp
|
||||
: LLVM_OneResultOp<"select", [NoSideEffect]>,
|
||||
: LLVM_OneResultOp<"select",
|
||||
[NoSideEffect, AllTypesMatch<["trueValue", "falseValue", "res"]>]>,
|
||||
Arguments<(ins LLVM_Type:$condition, LLVM_Type:$trueValue,
|
||||
LLVM_Type:$falseValue)>,
|
||||
LLVM_Builder<
|
||||
|
@ -454,8 +455,7 @@ def LLVM_SelectOp
|
|||
"Value rhs", [{
|
||||
build(b, result, lhs.getType(), condition, lhs, rhs);
|
||||
}]>];
|
||||
let parser = [{ return parseSelectOp(parser, result); }];
|
||||
let printer = [{ printSelectOp(p, *this); }];
|
||||
let assemblyFormat = "operands attr-dict `:` type($condition) `,` type($res)";
|
||||
}
|
||||
|
||||
// Terminators.
|
||||
|
|
|
@ -99,7 +99,8 @@ def SPV_BitCountOp : SPV_BitUnaryOp<"BitCount", []> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert", [NoSideEffect]> {
|
||||
def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert",
|
||||
[NoSideEffect, AllTypesMatch<["base", "insert", "result"]>]> {
|
||||
let summary = [{
|
||||
Make a copy of an object, with a modified bit field that comes from
|
||||
another object.
|
||||
|
@ -163,6 +164,12 @@ def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert", [NoSideEffect]> {
|
|||
let results = (outs
|
||||
SPV_ScalarOrVectorOf<SPV_Integer>:$result
|
||||
);
|
||||
|
||||
let verifier = [{ return success(); }];
|
||||
|
||||
let assemblyFormat = [{
|
||||
operands attr-dict `:` type($base) `,` type($offset) `,` type($count)
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -794,7 +794,8 @@ def SPV_SLessThanEqualOp : SPV_LogicalBinaryOp<"SLessThanEqual", SPV_Integer, []
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_SelectOp : SPV_Op<"Select", [NoSideEffect]> {
|
||||
def SPV_SelectOp : SPV_Op<"Select",
|
||||
[NoSideEffect, AllTypesMatch<["true_value", "false_value", "result"]>]> {
|
||||
let summary = [{
|
||||
Select between two objects. Before version 1.4, results are only
|
||||
computed per component.
|
||||
|
@ -851,6 +852,10 @@ def SPV_SelectOp : SPV_Op<"Select", [NoSideEffect]> {
|
|||
let builders = [OpBuilder<[{Builder *builder, OperationState &state,
|
||||
Value cond, Value trueValue,
|
||||
Value falseValue}]>];
|
||||
|
||||
let assemblyFormat = [{
|
||||
operands attr-dict `:` type($condition) `,` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -407,10 +407,9 @@ def Vector_InsertOp :
|
|||
Vector_Op<"insert", [NoSideEffect,
|
||||
PredOpTrait<"source operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
PredOpTrait<"dest operand and result have same type",
|
||||
TCresIsSameAsOpBase<0, 1>>]>,
|
||||
AllTypesMatch<["dest", "res"]>]>,
|
||||
Arguments<(ins AnyType:$source, AnyVector:$dest, I64ArrayAttr:$position)>,
|
||||
Results<(outs AnyVector)> {
|
||||
Results<(outs AnyVector:$res)> {
|
||||
let summary = "insert operation";
|
||||
let description = [{
|
||||
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
|
||||
|
@ -425,6 +424,10 @@ def Vector_InsertOp :
|
|||
f32 into vector<4x8x16xf32>
|
||||
```
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$source `,` $dest $position attr-dict `:` type($source) `into` type($dest)
|
||||
}];
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value source, " #
|
||||
"Value dest, ArrayRef<int64_t>">];
|
||||
|
@ -497,11 +500,10 @@ def Vector_InsertStridedSliceOp :
|
|||
Vector_Op<"insert_strided_slice", [NoSideEffect,
|
||||
PredOpTrait<"operand #0 and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
PredOpTrait<"dest operand and result have same type",
|
||||
TCresIsSameAsOpBase<0, 1>>]>,
|
||||
AllTypesMatch<["dest", "res"]>]>,
|
||||
Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
|
||||
I64ArrayAttr:$strides)>,
|
||||
Results<(outs AnyVector)> {
|
||||
Results<(outs AnyVector:$res)> {
|
||||
let summary = "strided_slice operation";
|
||||
let description = [{
|
||||
Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
|
||||
|
@ -522,6 +524,11 @@ def Vector_InsertStridedSliceOp :
|
|||
vector<2x4xf32> into vector<16x4x8xf32>
|
||||
```
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `,` $dest attr-dict `:` type($source) `into` type($dest)
|
||||
}];
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value source, Value dest, " #
|
||||
"ArrayRef<int64_t> offsets, ArrayRef<int64_t> strides">];
|
||||
|
|
|
@ -1658,7 +1658,9 @@ class AllMatchSameOperatorTrait<list<string> names, string operator,
|
|||
string description> :
|
||||
PredOpTrait<
|
||||
"all of {" # StrJoin<names>.result # "} have same " # description,
|
||||
AllMatchSameOperatorPred<names, operator>>;
|
||||
AllMatchSameOperatorPred<names, operator>> {
|
||||
list<string> values = names;
|
||||
}
|
||||
|
||||
class AllElementCountsMatch<list<string> names> :
|
||||
AllMatchSameOperatorTrait<names, ElementCount<"_self">.result,
|
||||
|
|
|
@ -49,6 +49,9 @@ public:
|
|||
|
||||
Kind getKind() const { return kind; }
|
||||
|
||||
// Returns the Tablegen definition this operator was constructed from.
|
||||
const llvm::Record &getDef() const { return *def; }
|
||||
|
||||
protected:
|
||||
// The TableGen definition of this trait.
|
||||
const llvm::Record *def;
|
||||
|
|
|
@ -781,40 +781,6 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for LLVM::SelectOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printSelectOp(OpAsmPrinter &p, SelectOp &op) {
|
||||
p << op.getOperationName() << ' ' << op.condition() << ", " << op.trueValue()
|
||||
<< ", " << op.falseValue();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.condition().getType() << ", " << op.trueValue().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use
|
||||
// attribute-dict? `:` type, type
|
||||
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType condition, trueValue, falseValue;
|
||||
Type conditionType, argType;
|
||||
|
||||
if (parser.parseOperand(condition) || parser.parseComma() ||
|
||||
parser.parseOperand(trueValue) || parser.parseComma() ||
|
||||
parser.parseOperand(falseValue) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(conditionType) || parser.parseComma() ||
|
||||
parser.parseType(argType))
|
||||
return failure();
|
||||
|
||||
if (parser.resolveOperand(condition, conditionType, result.operands) ||
|
||||
parser.resolveOperand(trueValue, argType, result.operands) ||
|
||||
parser.resolveOperand(falseValue, argType, result.operands))
|
||||
return failure();
|
||||
|
||||
result.addTypes(argType);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for LLVM::BrOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1065,54 +1065,6 @@ void spirv::BitcastOp::getCanonicalizationPatterns(
|
|||
results.insert<ConvertChainedBitcast>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.BitFieldInsert
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseBitFieldInsertOp(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
|
||||
Type baseType;
|
||||
Type offsetType;
|
||||
Type countType;
|
||||
auto loc = parser.getCurrentLocation();
|
||||
|
||||
if (parser.parseOperandList(operandInfo, 4) || parser.parseColon() ||
|
||||
parser.parseType(baseType) || parser.parseComma() ||
|
||||
parser.parseType(offsetType) || parser.parseComma() ||
|
||||
parser.parseType(countType) ||
|
||||
parser.resolveOperands(operandInfo,
|
||||
{baseType, baseType, offsetType, countType}, loc,
|
||||
state.operands)) {
|
||||
return failure();
|
||||
}
|
||||
state.addTypes(baseType);
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::BitFieldInsertOp bitFieldInsertOp,
|
||||
OpAsmPrinter &printer) {
|
||||
printer << spirv::BitFieldInsertOp::getOperationName() << ' '
|
||||
<< bitFieldInsertOp.getOperands() << " : "
|
||||
<< bitFieldInsertOp.base().getType() << ", "
|
||||
<< bitFieldInsertOp.offset().getType() << ", "
|
||||
<< bitFieldInsertOp.count().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::BitFieldInsertOp bitFieldOp) {
|
||||
auto baseType = bitFieldOp.base().getType();
|
||||
auto insertType = bitFieldOp.insert().getType();
|
||||
auto resultType = bitFieldOp.getResult().getType();
|
||||
|
||||
if ((baseType != insertType) || (baseType != resultType)) {
|
||||
return bitFieldOp.emitError("expected the same type for the base operand, "
|
||||
"insert operand, and "
|
||||
"result, but provided ")
|
||||
<< baseType << ", " << insertType << " and " << resultType;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.BranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2522,42 +2474,9 @@ void spirv::SelectOp::build(Builder *builder, OperationState &state, Value cond,
|
|||
build(builder, state, trueValue.getType(), cond, trueValue, falseValue);
|
||||
}
|
||||
|
||||
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
|
||||
OpAsmParser::OperandType condition;
|
||||
SmallVector<OpAsmParser::OperandType, 2> operands;
|
||||
SmallVector<Type, 2> types;
|
||||
auto loc = parser.getCurrentLocation();
|
||||
if (parser.parseOperand(condition) || parser.parseComma() ||
|
||||
parser.parseOperandList(operands, 2) ||
|
||||
parser.parseColonTypeList(types)) {
|
||||
return failure();
|
||||
}
|
||||
if (types.size() != 2) {
|
||||
return parser.emitError(
|
||||
loc, "need exactly two trailing types for select condition and object");
|
||||
}
|
||||
if (parser.resolveOperand(condition, types[0], state.operands) ||
|
||||
parser.resolveOperands(operands, types[1], state.operands)) {
|
||||
return failure();
|
||||
}
|
||||
return parser.addTypesToList(types[1], state.types);
|
||||
}
|
||||
|
||||
static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
|
||||
printer << spirv::SelectOp::getOperationName() << " " << op.getOperands()
|
||||
<< " : " << op.condition().getType() << ", " << op.result().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::SelectOp op) {
|
||||
auto resultTy = op.result().getType();
|
||||
if (op.true_value().getType() != resultTy) {
|
||||
return op.emitOpError("result type and true value type must be the same");
|
||||
}
|
||||
if (op.false_value().getType() != resultTy) {
|
||||
return op.emitOpError("result type and false value type must be the same");
|
||||
}
|
||||
if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
|
||||
auto resultVectorTy = resultTy.dyn_cast<VectorType>();
|
||||
auto resultVectorTy = op.result().getType().dyn_cast<VectorType>();
|
||||
if (!resultVectorTy) {
|
||||
return op.emitOpError("result expected to be of vector type when "
|
||||
"condition is of vector type");
|
||||
|
|
|
@ -700,31 +700,6 @@ void InsertOp::build(Builder *builder, OperationState &result, Value source,
|
|||
result.addAttribute(getPositionAttrName(), positionAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertOp op) {
|
||||
p << op.getOperationName() << " " << op.source() << ", " << op.dest()
|
||||
<< op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
|
||||
p << " : " << op.getSourceType() << " into " << op.getDestVectorType();
|
||||
}
|
||||
|
||||
static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
OpAsmParser::OperandType source, dest;
|
||||
Type sourceType;
|
||||
VectorType destType;
|
||||
Attribute attr;
|
||||
return failure(parser.parseOperand(source) || parser.parseComma() ||
|
||||
parser.parseOperand(dest) ||
|
||||
parser.parseAttribute(attr, InsertOp::getPositionAttrName(),
|
||||
result.attributes) ||
|
||||
parser.parseOptionalAttrDict(attrs) ||
|
||||
parser.parseColonType(sourceType) ||
|
||||
parser.parseKeywordType("into", destType) ||
|
||||
parser.resolveOperand(source, sourceType, result.operands) ||
|
||||
parser.resolveOperand(dest, destType, result.operands) ||
|
||||
parser.addTypeToList(destType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(InsertOp op) {
|
||||
auto positionAttr = op.position().getValue();
|
||||
if (positionAttr.empty())
|
||||
|
@ -793,27 +768,6 @@ void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
|
|||
result.addAttribute(getStridesAttrName(), stridesAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertStridedSliceOp op) {
|
||||
p << op.getOperationName() << " " << op.source() << ", " << op.dest() << " ";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getSourceVectorType() << " into " << op.getDestVectorType();
|
||||
}
|
||||
|
||||
static ParseResult parseInsertStridedSliceOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType source, dest;
|
||||
VectorType sourceVectorType, destVectorType;
|
||||
return failure(
|
||||
parser.parseOperand(source) || parser.parseComma() ||
|
||||
parser.parseOperand(dest) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(sourceVectorType) ||
|
||||
parser.parseKeywordType("into", destVectorType) ||
|
||||
parser.resolveOperand(source, sourceVectorType, result.operands) ||
|
||||
parser.resolveOperand(dest, destVectorType, result.operands) ||
|
||||
parser.addTypeToList(destVectorType, result.types));
|
||||
}
|
||||
|
||||
// TODO(ntv) Should be moved to Tablegen Confined attributes.
|
||||
template <typename OpType>
|
||||
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
|
||||
|
|
|
@ -227,7 +227,7 @@ func @bit_field_insert_vec(%base: vector<3xi32>, %insert: vector<3xi32>, %offset
|
|||
// -----
|
||||
|
||||
func @bit_field_insert_invalid_insert_type(%base: vector<3xi32>, %insert: vector<2xi32>, %offset: i32, %count: i16) -> vector<3xi32> {
|
||||
// expected-error @+1 {{expected the same type for the base operand, insert operand, and result, but provided 'vector<3xi32>', 'vector<2xi32>' and 'vector<3xi32>'}}
|
||||
// expected-error @+1 {{all of {base, insert, result} have same type}}
|
||||
%0 = "spv.BitFieldInsert" (%base, %insert, %offset, %count) : (vector<3xi32>, vector<2xi32>, i32, i16) -> vector<3xi32>
|
||||
spv.ReturnValue %0 : vector<3xi32>
|
||||
}
|
||||
|
@ -856,7 +856,7 @@ func @select_op_vec_condn_vec(%arg0: vector<3xi1>) -> () {
|
|||
func @select_op(%arg0: i1) -> () {
|
||||
%0 = spv.constant 2 : i32
|
||||
%1 = spv.constant 3 : i32
|
||||
// expected-error @+1 {{need exactly two trailing types for select condition and object}}
|
||||
// expected-error @+2 {{expected ','}}
|
||||
%2 = spv.Select %arg0, %0, %1 : i1
|
||||
return
|
||||
}
|
||||
|
@ -886,7 +886,7 @@ func @select_op(%arg1: vector<4xi1>) -> () {
|
|||
func @select_op(%arg1: vector<4xi1>) -> () {
|
||||
%0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
|
||||
%1 = spv.constant dense<[5, 6, 7]> : vector<3xi32>
|
||||
// expected-error @+1 {{op result type and true value type must be the same}}
|
||||
// expected-error @+1 {{all of {true_value, false_value, result} have same type}}
|
||||
%2 = "spv.Select"(%arg1, %0, %1) : (vector<4xi1>, vector<3xf32>, vector<3xi32>) -> vector<3xi32>
|
||||
return
|
||||
}
|
||||
|
@ -896,7 +896,7 @@ func @select_op(%arg1: vector<4xi1>) -> () {
|
|||
func @select_op(%arg1: vector<4xi1>) -> () {
|
||||
%0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
|
||||
%1 = spv.constant dense<[5, 6, 7]> : vector<3xi32>
|
||||
// expected-error @+1 {{op result type and false value type must be the same}}
|
||||
// expected-error @+1 {{all of {true_value, false_value, result} have same type}}
|
||||
%2 = "spv.Select"(%arg1, %1, %0) : (vector<4xi1>, vector<3xi32>, vector<3xf32>) -> vector<3xi32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -7,7 +7,8 @@ include "mlir/IR/OpBase.td"
|
|||
def TestDialect : Dialect {
|
||||
let name = "test";
|
||||
}
|
||||
class TestFormat_Op<string name, string fmt> : Op<TestDialect, name> {
|
||||
class TestFormat_Op<string name, string fmt, list<OpTrait> traits = []>
|
||||
: Op<TestDialect, name, traits> {
|
||||
let assemblyFormat = fmt;
|
||||
}
|
||||
|
||||
|
@ -234,3 +235,24 @@ def ZCoverageValidC : TestFormat_Op<"variable_valid_c", [{
|
|||
operands functional-type(operands, results) attr-dict
|
||||
}]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
|
||||
|
||||
// Check that we can infer type equalities from certain traits.
|
||||
def ZCoverageValidD : TestFormat_Op<"variable_valid_d", [{
|
||||
operands type($result) attr-dict
|
||||
}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>,
|
||||
Results<(outs AnyMemRef:$result)>;
|
||||
def ZCoverageValidE : TestFormat_Op<"variable_valid_e", [{
|
||||
$operand type($operand) attr-dict
|
||||
}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>,
|
||||
Results<(outs AnyMemRef:$result)>;
|
||||
def ZCoverageValidF : TestFormat_Op<"variable_valid_f", [{
|
||||
operands type($other) attr-dict
|
||||
}], [SameTypeOperands]>, Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>;
|
||||
def ZCoverageValidG : TestFormat_Op<"variable_valid_g", [{
|
||||
operands type($other) attr-dict
|
||||
}], [AllTypesMatch<["operand", "other"]>]>,
|
||||
Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>;
|
||||
def ZCoverageValidH : TestFormat_Op<"variable_valid_h", [{
|
||||
operands type($result) attr-dict
|
||||
}], [AllTypesMatch<["operand", "result"]>]>,
|
||||
Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
|
||||
|
||||
|
|
|
@ -202,10 +202,32 @@ bool LiteralElement::isValidLiteral(StringRef value) {
|
|||
|
||||
namespace {
|
||||
struct OperationFormat {
|
||||
/// This class represents a specific resolver for an operand or result type.
|
||||
class TypeResolution {
|
||||
public:
|
||||
TypeResolution() = default;
|
||||
|
||||
/// Get the index into the buildable types for this type, or None.
|
||||
Optional<int> getBuilderIdx() const { return builderIdx; }
|
||||
void setBuilderIdx(int idx) { builderIdx = idx; }
|
||||
|
||||
/// Get the variable this type is resolved to, or None.
|
||||
Optional<StringRef> getVariable() const { return variableName; }
|
||||
void setVariable(StringRef variable) { variableName = variable; }
|
||||
|
||||
private:
|
||||
/// If the type is resolved with a buildable type, this is the index into
|
||||
/// 'buildableTypes' in the parent format.
|
||||
Optional<int> builderIdx;
|
||||
/// If the type is resolved based upon another operand or result, this is
|
||||
/// the name of the variable that this type is resolved to.
|
||||
Optional<StringRef> variableName;
|
||||
};
|
||||
|
||||
OperationFormat(const Operator &op)
|
||||
: allOperandTypes(false), allResultTypes(false) {
|
||||
buildableOperandTypes.resize(op.getNumOperands(), llvm::None);
|
||||
buildableResultTypes.resize(op.getNumResults(), llvm::None);
|
||||
operandTypes.resize(op.getNumOperands(), TypeResolution());
|
||||
resultTypes.resize(op.getNumResults(), TypeResolution());
|
||||
}
|
||||
|
||||
/// Generate the operation parser from this format.
|
||||
|
@ -228,7 +250,7 @@ struct OperationFormat {
|
|||
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
|
||||
|
||||
/// The index of the buildable type, if valid, for every operand and result.
|
||||
std::vector<Optional<int>> buildableOperandTypes, buildableResultTypes;
|
||||
std::vector<TypeResolution> operandTypes, resultTypes;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -398,8 +420,10 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
} else {
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
|
||||
body << " result.addTypes(";
|
||||
if (Optional<int> val = buildableResultTypes[i])
|
||||
if (Optional<int> val = resultTypes[i].getBuilderIdx())
|
||||
body << "odsBuildableType" << *val;
|
||||
else if (Optional<StringRef> var = resultTypes[i].getVariable())
|
||||
body << *var << "Types";
|
||||
else
|
||||
body << op.getResultName(i) << "Types";
|
||||
body << ");\n";
|
||||
|
@ -450,8 +474,10 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
if (op.getNumOperands() > 1) {
|
||||
body << "llvm::concat<const Type>(";
|
||||
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
|
||||
if (Optional<int> val = buildableOperandTypes[i])
|
||||
if (Optional<int> val = operandTypes[i].getBuilderIdx())
|
||||
body << "ArrayRef<Type>(odsBuildableType" << *val << ")";
|
||||
else if (Optional<StringRef> var = operandTypes[i].getVariable())
|
||||
body << *var << "Types";
|
||||
else
|
||||
body << op.getOperand(i).name << "Types";
|
||||
});
|
||||
|
@ -470,8 +496,10 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
NamedTypeConstraint &operand = op.getOperand(i);
|
||||
body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
|
||||
if (Optional<int> val = buildableOperandTypes[i])
|
||||
if (Optional<int> val = operandTypes[i].getBuilderIdx())
|
||||
body << "odsBuildableType" << *val << ", ";
|
||||
else if (Optional<StringRef> var = operandTypes[i].getVariable())
|
||||
body << *var << "Types, " << operand.name << "OperandsLoc, ";
|
||||
else
|
||||
body << operand.name << "Types, " << operand.name << "OperandsLoc, ";
|
||||
body << "result.operands))\n return failure();\n";
|
||||
|
@ -803,6 +831,13 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
|
|||
// FormatParser
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Function to find an element within the given range that has the same name as
|
||||
/// 'name'.
|
||||
template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
|
||||
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
|
||||
return it != range.end() ? &*it : nullptr;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// This class implements a parser for an instance of an operation assembly
|
||||
/// format.
|
||||
|
@ -817,6 +852,18 @@ public:
|
|||
LogicalResult parse();
|
||||
|
||||
private:
|
||||
/// Given the values of an `AllTypesMatch` trait, check for inferrable type
|
||||
/// resolution.
|
||||
void handleAllTypesMatchConstraint(
|
||||
ArrayRef<StringRef> values,
|
||||
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver);
|
||||
/// Check for inferrable type resolution given all operands, and or results,
|
||||
/// have the same type. If 'includeResults' is true, the results also have the
|
||||
/// same type as all of the operands.
|
||||
void handleSameTypesConstraint(
|
||||
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
|
||||
bool includeResults);
|
||||
|
||||
/// Parse a specific element.
|
||||
LogicalResult parseElement(std::unique_ptr<Element> &element,
|
||||
bool isTopLevel);
|
||||
|
@ -870,8 +917,8 @@ private:
|
|||
OperationFormat &fmt;
|
||||
Operator &op;
|
||||
|
||||
// The following are various bits of format state used for verification during
|
||||
// parsing.
|
||||
// The following are various bits of format state used for verification
|
||||
// during parsing.
|
||||
bool hasAllOperands = false, hasAttrDict = false;
|
||||
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
|
||||
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
|
||||
|
@ -894,6 +941,19 @@ LogicalResult FormatParser::parse() {
|
|||
if (!hasAttrDict)
|
||||
return emitError(loc, "format missing 'attr-dict' directive");
|
||||
|
||||
// Check for any type traits that we can use for inferring types.
|
||||
llvm::StringMap<const NamedTypeConstraint *> variableTyResolver;
|
||||
for (const OpTrait &trait : op.getTraits()) {
|
||||
const llvm::Record &def = trait.getDef();
|
||||
if (def.isSubClassOf("AllTypesMatch"))
|
||||
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
|
||||
variableTyResolver);
|
||||
else if (def.getName() == "SameTypeOperands")
|
||||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
|
||||
else if (def.getName() == "SameOperandsAndResultType")
|
||||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
|
||||
}
|
||||
|
||||
// Check that all of the result types can be inferred.
|
||||
auto &buildableTypes = fmt.buildableTypes;
|
||||
if (!fmt.allResultTypes) {
|
||||
|
@ -901,6 +961,13 @@ LogicalResult FormatParser::parse() {
|
|||
if (seenResultTypes.test(i))
|
||||
continue;
|
||||
|
||||
// Check to see if we can infer this type from another variable.
|
||||
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
|
||||
if (varResolverIt != variableTyResolver.end()) {
|
||||
fmt.resultTypes[i].setVariable(varResolverIt->second->name);
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the result is not variadic, allow for the case where the type has a
|
||||
// builder that we can use.
|
||||
NamedTypeConstraint &result = op.getResult(i);
|
||||
|
@ -911,7 +978,7 @@ LogicalResult FormatParser::parse() {
|
|||
}
|
||||
// Note in the format that this result uses the custom builder.
|
||||
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
|
||||
fmt.buildableResultTypes[i] = it.first->second;
|
||||
fmt.resultTypes[i].setBuilderIdx(it.first->second);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -927,21 +994,78 @@ LogicalResult FormatParser::parse() {
|
|||
}
|
||||
|
||||
// Check that the operand type is in the format, or that it can be inferred.
|
||||
if (!fmt.allOperandTypes && !seenOperandTypes.test(i)) {
|
||||
// Similarly to results, allow a custom builder for resolving the type if
|
||||
// we aren't using the 'operands' directive.
|
||||
Optional<StringRef> builder = operand.constraint.getBuilderCall();
|
||||
if (!builder || (hasAllOperands && operand.isVariadic())) {
|
||||
return emitError(loc, "format missing instance of operand #" +
|
||||
Twine(i) + "('" + operand.name + "') type");
|
||||
}
|
||||
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
|
||||
fmt.buildableOperandTypes[i] = it.first->second;
|
||||
if (fmt.allOperandTypes || seenOperandTypes.test(i))
|
||||
continue;
|
||||
|
||||
// Check to see if we can infer this type from another variable.
|
||||
auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
|
||||
if (varResolverIt != variableTyResolver.end()) {
|
||||
fmt.operandTypes[i].setVariable(varResolverIt->second->name);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Similarly to results, allow a custom builder for resolving the type if
|
||||
// we aren't using the 'operands' directive.
|
||||
Optional<StringRef> builder = operand.constraint.getBuilderCall();
|
||||
if (!builder || (hasAllOperands && operand.isVariadic())) {
|
||||
return emitError(loc, "format missing instance of operand #" + Twine(i) +
|
||||
"('" + operand.name + "') type");
|
||||
}
|
||||
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
|
||||
fmt.operandTypes[i].setBuilderIdx(it.first->second);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void FormatParser::handleAllTypesMatchConstraint(
|
||||
ArrayRef<StringRef> values,
|
||||
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver) {
|
||||
for (unsigned i = 0, e = values.size(); i != e; ++i) {
|
||||
// Check to see if this value matches a resolved operand or result type.
|
||||
const NamedTypeConstraint *arg = nullptr;
|
||||
if ((arg = findArg(op.getOperands(), values[i]))) {
|
||||
if (!seenOperandTypes.test(arg - op.operand_begin()))
|
||||
continue;
|
||||
} else if ((arg = findArg(op.getResults(), values[i]))) {
|
||||
if (!seenResultTypes.test(arg - op.result_begin()))
|
||||
continue;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Mark this value as the type resolver for the other variables.
|
||||
for (unsigned j = 0; j != i; ++j)
|
||||
variableTyResolver[values[j]] = arg;
|
||||
for (unsigned j = i + 1; j != e; ++j)
|
||||
variableTyResolver[values[j]] = arg;
|
||||
}
|
||||
}
|
||||
|
||||
void FormatParser::handleSameTypesConstraint(
|
||||
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
|
||||
bool includeResults) {
|
||||
const NamedTypeConstraint *resolver = nullptr;
|
||||
int resolvedIt = -1;
|
||||
|
||||
// Check to see if there is an operand or result to use for the resolution.
|
||||
if ((resolvedIt = seenOperandTypes.find_first()) != -1)
|
||||
resolver = &op.getOperand(resolvedIt);
|
||||
else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
|
||||
resolver = &op.getResult(resolvedIt);
|
||||
else
|
||||
return;
|
||||
|
||||
// Set the resolvers for each operand and result.
|
||||
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
|
||||
if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
|
||||
variableTyResolver[op.getOperand(i).name] = resolver;
|
||||
if (includeResults) {
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
|
||||
if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
|
||||
variableTyResolver[op.getResultName(i)] = resolver;
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
|
||||
bool isTopLevel) {
|
||||
// Directives.
|
||||
|
@ -965,23 +1089,16 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
|
|||
StringRef name = varTok.getSpelling().drop_front();
|
||||
llvm::SMLoc loc = varTok.getLoc();
|
||||
|
||||
// Functor used to find an element within the given range that has the same
|
||||
// name as 'name'.
|
||||
auto findArg = [&](auto &&range) {
|
||||
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
|
||||
return it != range.end() ? &*it : nullptr;
|
||||
};
|
||||
|
||||
// Check that the parsed argument is something actually registered on the op.
|
||||
/// Attributes
|
||||
if (const NamedAttribute *attr = findArg(op.getAttributes())) {
|
||||
if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
|
||||
if (isTopLevel && !seenAttrs.insert(attr).second)
|
||||
return emitError(loc, "attribute '" + name + "' is already bound");
|
||||
element = std::make_unique<AttributeVariable>(attr);
|
||||
return success();
|
||||
}
|
||||
/// Operands
|
||||
if (const NamedTypeConstraint *operand = findArg(op.getOperands())) {
|
||||
if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
|
||||
if (isTopLevel) {
|
||||
if (hasAllOperands || !seenOperands.insert(operand).second)
|
||||
return emitError(loc, "operand '" + name + "' is already bound");
|
||||
|
@ -990,7 +1107,7 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
|
|||
return success();
|
||||
}
|
||||
/// Results.
|
||||
if (const NamedTypeConstraint *result = findArg(op.getResults())) {
|
||||
if (const auto *result = findArg(op.getResults(), name)) {
|
||||
if (isTopLevel)
|
||||
return emitError(loc, "results can not be used at the top level");
|
||||
element = std::make_unique<ResultVariable>(result);
|
||||
|
|
Loading…
Reference in New Issue