forked from OSchip/llvm-project
[mlir][DeclarativeParser] Add support for the TypesMatchWith trait.
This allows for injecting type constraints that are not direct 1-1 mappings, for example when one type is equal to the element type of another. This allows for moving over several more parsers to the declarative form. Differential Revision: https://reviews.llvm.org/D74648
This commit is contained in:
parent
393f4e8ac2
commit
26222db01b
|
@ -357,6 +357,8 @@ def CallIndirectOp : Std_Op<"call_indirect", [
|
|||
|
||||
let verifier = ?;
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)";
|
||||
}
|
||||
|
||||
def CeilFOp : FloatUnaryOp<"ceilf"> {
|
||||
|
@ -490,6 +492,8 @@ def CmpIOp : Std_Op<"cmpi",
|
|||
let verifier = [{ return success(); }];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
|
||||
}
|
||||
|
||||
def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
|
||||
|
@ -761,6 +765,10 @@ def ExtractElementOp : Std_Op<"extract_element",
|
|||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$aggregate `[` $indices `]` attr-dict `:` type($aggregate)
|
||||
}];
|
||||
}
|
||||
|
||||
def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
|
||||
|
@ -853,6 +861,8 @@ def LoadOp : Std_Op<"load",
|
|||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
|
||||
}
|
||||
|
||||
def LogOp : FloatUnaryOp<"log"> {
|
||||
|
@ -1090,6 +1100,10 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
|
|||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$condition `,` $true_value `,` $false_value attr-dict `:` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
def SignExtendIOp : Std_Op<"sexti",
|
||||
|
@ -1222,6 +1236,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
|
|||
[{ build(builder, result, aggregateType, element); }]>];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
|
||||
}
|
||||
|
||||
def StoreOp : Std_Op<"store",
|
||||
|
@ -1264,6 +1280,10 @@ def StoreOp : Std_Op<"store",
|
|||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
|
||||
}];
|
||||
}
|
||||
|
||||
def SubFOp : FloatArithmeticOp<"subf"> {
|
||||
|
@ -1517,11 +1537,12 @@ def TensorLoadOp : Std_Op<"tensor_load",
|
|||
result.addTypes(resultType);
|
||||
}]>];
|
||||
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// The result of a tensor_load is always a tensor.
|
||||
TensorType getType() { return getResult().getType().cast<TensorType>(); }
|
||||
}];
|
||||
|
||||
let assemblyFormat = "$memref attr-dict `:` type($memref)";
|
||||
}
|
||||
|
||||
def TensorStoreOp : Std_Op<"tensor_store",
|
||||
|
@ -1545,6 +1566,8 @@ def TensorStoreOp : Std_Op<"tensor_store",
|
|||
let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref);
|
||||
// TensorStoreOp is fully verified by traits.
|
||||
let verifier = ?;
|
||||
|
||||
let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
|
||||
}
|
||||
|
||||
def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
|
|
|
@ -363,6 +363,10 @@ def Vector_ExtractElementOp :
|
|||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_ExtractOp :
|
||||
|
@ -512,6 +516,11 @@ def Vector_InsertElementOp :
|
|||
return dest().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
|
||||
type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_InsertOp :
|
||||
|
|
|
@ -496,6 +496,12 @@ public:
|
|||
return failure();
|
||||
return success();
|
||||
}
|
||||
template <typename Operands>
|
||||
ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc,
|
||||
SmallVectorImpl<Value> &result) {
|
||||
return resolveOperands(std::forward<Operands>(operands),
|
||||
ArrayRef<Type>(type), loc, result);
|
||||
}
|
||||
template <typename Operands, typename Types>
|
||||
ParseResult resolveOperands(Operands &&operands, Types &&types,
|
||||
llvm::SMLoc loc, SmallVectorImpl<Value> &result) {
|
||||
|
|
|
@ -294,6 +294,11 @@ public:
|
|||
void addTypes(ArrayRef<Type> newTypes) {
|
||||
types.append(newTypes.begin(), newTypes.end());
|
||||
}
|
||||
template <typename RangeT>
|
||||
std::enable_if_t<!std::is_convertible<RangeT, ArrayRef<Type>>::value>
|
||||
addTypes(RangeT &&newTypes) {
|
||||
types.append(newTypes.begin(), newTypes.end());
|
||||
}
|
||||
|
||||
/// Add an attribute with the specified name.
|
||||
void addAttribute(StringRef name, Attribute attr) {
|
||||
|
|
|
@ -505,29 +505,6 @@ struct SimplifyIndirectCallWithKnownCallee
|
|||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
static ParseResult parseCallIndirectOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
FunctionType calleeType;
|
||||
OpAsmParser::OperandType callee;
|
||||
llvm::SMLoc operandsLoc;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||
return failure(
|
||||
parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) ||
|
||||
parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(calleeType) ||
|
||||
parser.resolveOperand(callee, calleeType, result.operands) ||
|
||||
parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc,
|
||||
result.operands) ||
|
||||
parser.addTypesToList(calleeType.getResults(), result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, CallIndirectOp op) {
|
||||
p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
p << " : " << op.getCallee().getType();
|
||||
}
|
||||
|
||||
void CallIndirectOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<SimplifyIndirectCallWithKnownCallee>(context);
|
||||
|
@ -570,55 +547,6 @@ static void buildCmpIOp(Builder *build, OperationState &result,
|
|||
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
|
||||
}
|
||||
|
||||
static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> ops;
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
Attribute predicateNameAttr;
|
||||
Type type;
|
||||
if (parser.parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(),
|
||||
attrs) ||
|
||||
parser.parseComma() || parser.parseOperandList(ops, 2) ||
|
||||
parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
|
||||
parser.resolveOperands(ops, type, result.operands))
|
||||
return failure();
|
||||
|
||||
if (!predicateNameAttr.isa<StringAttr>())
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected string comparison predicate attribute");
|
||||
|
||||
// Rewrite string attribute to an enum value.
|
||||
StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
|
||||
Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName);
|
||||
if (!predicate.hasValue())
|
||||
return parser.emitError(parser.getNameLoc())
|
||||
<< "unknown comparison predicate \"" << predicateName << "\"";
|
||||
|
||||
auto builder = parser.getBuilder();
|
||||
Type i1Type = getCheckedI1SameShape(type);
|
||||
if (!i1Type)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected type with valid i1 shape");
|
||||
|
||||
attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate));
|
||||
result.attributes = attrs;
|
||||
|
||||
result.addTypes({i1Type});
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, CmpIOp op) {
|
||||
p << "cmpi ";
|
||||
|
||||
Builder b(op.getContext());
|
||||
auto predicateValue =
|
||||
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
|
||||
p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
|
||||
<< '"' << ", " << op.lhs() << ", " << op.rhs();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
|
||||
p << " : " << op.lhs().getType();
|
||||
}
|
||||
|
||||
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
|
||||
// comparison predicates.
|
||||
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
|
||||
|
@ -1486,30 +1414,6 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
|
|||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, ExtractElementOp op) {
|
||||
p << "extract_element " << op.getAggregate() << '[' << op.getIndices();
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getAggregate().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType aggregateInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||
ShapedType type;
|
||||
|
||||
auto indexTy = parser.getBuilder().getIndexType();
|
||||
return failure(
|
||||
parser.parseOperand(aggregateInfo) ||
|
||||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.resolveOperand(aggregateInfo, type, result.operands) ||
|
||||
parser.resolveOperands(indexInfo, indexTy, result.operands) ||
|
||||
parser.addTypeToList(type.getElementType(), result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(ExtractElementOp op) {
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
|
||||
|
@ -1577,28 +1481,6 @@ OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
|
|||
// LoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, LoadOp op) {
|
||||
p << "load " << op.getMemRef() << '[' << op.getIndices() << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getMemRefType();
|
||||
}
|
||||
|
||||
static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType memrefInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||
MemRefType type;
|
||||
|
||||
auto indexTy = parser.getBuilder().getIndexType();
|
||||
return failure(
|
||||
parser.parseOperand(memrefInfo) ||
|
||||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.resolveOperand(memrefInfo, type, result.operands) ||
|
||||
parser.resolveOperands(indexInfo, indexTy, result.operands) ||
|
||||
parser.addTypeToList(type.getElementType(), result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(LoadOp op) {
|
||||
if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
|
||||
return op.emitOpError("incorrect number of indices for load");
|
||||
|
@ -1902,31 +1784,6 @@ bool SIToFPOp::areCastCompatible(Type a, Type b) {
|
|||
// SelectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 3> ops;
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
Type type;
|
||||
if (parser.parseOperandList(ops, 3) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
auto i1Type = getCheckedI1SameShape(type);
|
||||
if (!i1Type)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected type with valid i1 shape");
|
||||
|
||||
std::array<Type, 3> types = {i1Type, type, type};
|
||||
return failure(parser.resolveOperands(ops, types, parser.getNameLoc(),
|
||||
result.operands) ||
|
||||
parser.addTypeToList(type, result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, SelectOp op) {
|
||||
p << "select " << op.getOperands() << " : " << op.getTrueValue().getType();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto condition = getCondition();
|
||||
|
||||
|
@ -1968,25 +1825,6 @@ static LogicalResult verify(SignExtendIOp op) {
|
|||
// SplatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, SplatOp op) {
|
||||
p << "splat " << op.getOperand();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getType();
|
||||
}
|
||||
|
||||
static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType splatValueInfo;
|
||||
ShapedType shapedType;
|
||||
|
||||
return failure(parser.parseOperand(splatValueInfo) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(shapedType) ||
|
||||
parser.resolveOperand(splatValueInfo,
|
||||
shapedType.getElementType(),
|
||||
result.operands) ||
|
||||
parser.addTypeToList(shapedType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(SplatOp op) {
|
||||
// TODO: we could replace this by a trait.
|
||||
if (op.getOperand().getType() !=
|
||||
|
@ -2017,32 +1855,6 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// StoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, StoreOp op) {
|
||||
p << "store " << op.getValueToStore();
|
||||
p << ", " << op.getMemRef() << '[' << op.getIndices() << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getMemRefType();
|
||||
}
|
||||
|
||||
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType storeValueInfo;
|
||||
OpAsmParser::OperandType memrefInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||
MemRefType memrefType;
|
||||
|
||||
auto indexTy = parser.getBuilder().getIndexType();
|
||||
return failure(
|
||||
parser.parseOperand(storeValueInfo) || parser.parseComma() ||
|
||||
parser.parseOperand(memrefInfo) ||
|
||||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(memrefType) ||
|
||||
parser.resolveOperand(storeValueInfo, memrefType.getElementType(),
|
||||
result.operands) ||
|
||||
parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
|
||||
parser.resolveOperands(indexInfo, indexTy, result.operands));
|
||||
}
|
||||
|
||||
static LogicalResult verify(StoreOp op) {
|
||||
if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
|
||||
return op.emitOpError("store index operand count not equal to memref rank");
|
||||
|
@ -2156,51 +1968,6 @@ static Type getTensorTypeFromMemRefType(Type type) {
|
|||
return NoneType::get(type.getContext());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorLoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, TensorLoadOp op) {
|
||||
p << "tensor_load " << op.getOperand();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getOperand().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseTensorLoadOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType op;
|
||||
Type type;
|
||||
return failure(
|
||||
parser.parseOperand(op) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.resolveOperand(op, type, result.operands) ||
|
||||
parser.addTypeToList(getTensorTypeFromMemRefType(type), result.types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorStoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, TensorStoreOp op) {
|
||||
p << "tensor_store " << op.tensor() << ", " << op.memref();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.memref().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseTensorStoreOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> ops;
|
||||
Type type;
|
||||
llvm::SMLoc loc = parser.getCurrentLocation();
|
||||
return failure(
|
||||
parser.parseOperandList(ops, /*requiredOperandCount=*/2) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.resolveOperands(ops, {getTensorTypeFromMemRefType(type), type},
|
||||
loc, result.operands));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TruncateIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -412,31 +412,6 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
|
|||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
|
||||
p << op.getOperationName() << " " << op.vector() << "[" << op.position()
|
||||
<< " : " << op.position().getType() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.vector().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType vector, position;
|
||||
Type positionType;
|
||||
VectorType vectorType;
|
||||
if (parser.parseOperand(vector) || parser.parseLSquare() ||
|
||||
parser.parseOperand(position) || parser.parseColonType(positionType) ||
|
||||
parser.parseRSquare() ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(vectorType))
|
||||
return failure();
|
||||
Type resultType = vectorType.getElementType();
|
||||
return failure(
|
||||
parser.resolveOperand(vector, vectorType, result.operands) ||
|
||||
parser.resolveOperand(position, positionType, result.operands) ||
|
||||
parser.addTypeToList(resultType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(vector::ExtractElementOp op) {
|
||||
VectorType vectorType = op.getVectorType();
|
||||
if (vectorType.getRank() != 1)
|
||||
|
@ -715,33 +690,6 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
|
|||
// InsertElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertElementOp op) {
|
||||
p << op.getOperationName() << " " << op.source() << ", " << op.dest() << "["
|
||||
<< op.position() << " : " << op.position().getType() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.dest().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseInsertElementOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType source, dest, position;
|
||||
Type positionType;
|
||||
VectorType destType;
|
||||
if (parser.parseOperand(source) || parser.parseComma() ||
|
||||
parser.parseOperand(dest) || parser.parseLSquare() ||
|
||||
parser.parseOperand(position) || parser.parseColonType(positionType) ||
|
||||
parser.parseRSquare() ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(destType))
|
||||
return failure();
|
||||
Type sourceType = destType.getElementType();
|
||||
return failure(
|
||||
parser.resolveOperand(source, sourceType, result.operands) ||
|
||||
parser.resolveOperand(dest, destType, result.operands) ||
|
||||
parser.resolveOperand(position, positionType, result.operands) ||
|
||||
parser.addTypeToList(destType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(InsertElementOp op) {
|
||||
auto dstVectorType = op.getDestVectorType();
|
||||
if (dstVectorType.getRank() != 1)
|
||||
|
|
|
@ -226,7 +226,7 @@ func @func_with_ops(i32, i32) {
|
|||
// Integer comparisons are not recognized for float types.
|
||||
func @func_with_ops(f32, f32) {
|
||||
^bb0(%a : f32, %b : f32):
|
||||
%r = cmpi "eq", %a, %b : f32 // expected-error {{operand #0 must be integer-like}}
|
||||
%r = cmpi "eq", %a, %b : f32 // expected-error {{'lhs' must be integer-like, but got 'f32'}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -298,13 +298,13 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
|
|||
// -----
|
||||
|
||||
func @invalid_select_shape(%cond : i1, %idx : () -> ()) {
|
||||
// expected-error@+1 {{expected type with valid i1 shape}}
|
||||
// expected-error@+1 {{'result' must be integer-like or floating-point-like, but got '() -> ()'}}
|
||||
%sel = select %cond, %idx, %idx : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_cmp_shape(%idx : () -> ()) {
|
||||
// expected-error@+1 {{expected type with valid i1 shape}}
|
||||
// expected-error@+1 {{'lhs' must be integer-like, but got '() -> ()'}}
|
||||
%cmp = cmpi "eq", %idx, %idx : () -> ()
|
||||
|
||||
// -----
|
||||
|
@ -340,7 +340,7 @@ func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
|
|||
// -----
|
||||
|
||||
func @invalid_cmp_attr(%idx : i32) {
|
||||
// expected-error@+1 {{expected string comparison predicate attribute}}
|
||||
// expected-error@+1 {{invalid kind of attribute specified}}
|
||||
%cmp = cmpi i1, %idx, %idx : i32
|
||||
|
||||
// -----
|
||||
|
|
|
@ -219,16 +219,26 @@ struct OperationFormat {
|
|||
void setBuilderIdx(int idx) { builderIdx = idx; }
|
||||
|
||||
/// Get the variable this type is resolved to, or None.
|
||||
Optional<StringRef> getVariable() const { return variableName; }
|
||||
void setVariable(StringRef variable) { variableName = variable; }
|
||||
const NamedTypeConstraint *getVariable() const { return variable; }
|
||||
Optional<StringRef> getVarTransformer() const {
|
||||
return variableTransformer;
|
||||
}
|
||||
void setVariable(const NamedTypeConstraint *var,
|
||||
Optional<StringRef> transformer) {
|
||||
variable = var;
|
||||
variableTransformer = transformer;
|
||||
}
|
||||
|
||||
private:
|
||||
/// If the type is resolved with a buildable type, this is the index into
|
||||
/// 'buildableTypes' in the parent format.
|
||||
Optional<int> builderIdx;
|
||||
/// If the type is resolved based upon another operand or result, this is
|
||||
/// the name of the variable that this type is resolved to.
|
||||
Optional<StringRef> variableName;
|
||||
/// the variable that this type is resolved to.
|
||||
const NamedTypeConstraint *variable;
|
||||
/// If the type is resolved based upon another operand or result, this is
|
||||
/// a transformer to apply to the variable when resolving.
|
||||
Optional<StringRef> variableTransformer;
|
||||
};
|
||||
|
||||
OperationFormat(const Operator &op)
|
||||
|
@ -487,6 +497,34 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
|
|||
|
||||
void OperationFormat::genParserTypeResolution(Operator &op,
|
||||
OpMethodBody &body) {
|
||||
// If any of type resolutions use transformed variables, make sure that the
|
||||
// types of those variables are resolved.
|
||||
SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
|
||||
FmtContext verifierFCtx;
|
||||
for (TypeResolution &resolver :
|
||||
llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
|
||||
Optional<StringRef> transformer = resolver.getVarTransformer();
|
||||
if (!transformer)
|
||||
continue;
|
||||
// Ensure that we don't verify the same variables twice.
|
||||
const NamedTypeConstraint *variable = resolver.getVariable();
|
||||
if (!verifiedVariables.insert(variable).second)
|
||||
continue;
|
||||
|
||||
auto constraint = variable->constraint;
|
||||
body << " for (Type type : " << variable->name << "Types) {\n"
|
||||
<< " (void)type;\n"
|
||||
<< " if (!("
|
||||
<< tgfmt(constraint.getConditionTemplate(),
|
||||
&verifierFCtx.withSelf("type"))
|
||||
<< ")) {\n"
|
||||
<< formatv(" return parser.emitError(parser.getNameLoc()) << "
|
||||
"\"'{0}' must be {1}, but got \" << type;\n",
|
||||
variable->name, constraint.getDescription())
|
||||
<< " }\n"
|
||||
<< " }\n";
|
||||
}
|
||||
|
||||
// Initialize the set of buildable types.
|
||||
if (!buildableTypes.empty()) {
|
||||
body << " Builder &builder = parser.getBuilder();\n";
|
||||
|
@ -498,18 +536,27 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
<< tgfmt(it.first, &typeBuilderCtx) << ";\n";
|
||||
}
|
||||
|
||||
// Emit the code necessary for a type resolver.
|
||||
auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
|
||||
if (Optional<int> val = resolver.getBuilderIdx()) {
|
||||
body << "odsBuildableType" << *val;
|
||||
} else if (const NamedTypeConstraint *var = resolver.getVariable()) {
|
||||
if (Optional<StringRef> tform = resolver.getVarTransformer())
|
||||
body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
|
||||
else
|
||||
body << var->name << "Types";
|
||||
} else {
|
||||
body << curVar << "Types";
|
||||
}
|
||||
};
|
||||
|
||||
// Resolve each of the result types.
|
||||
if (allResultTypes) {
|
||||
body << " result.addTypes(allResultTypes);\n";
|
||||
} else {
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
|
||||
body << " result.addTypes(";
|
||||
if (Optional<int> val = resultTypes[i].getBuilderIdx())
|
||||
body << "odsBuildableType" << *val;
|
||||
else if (Optional<StringRef> var = resultTypes[i].getVariable())
|
||||
body << *var << "Types";
|
||||
else
|
||||
body << op.getResultName(i) << "Types";
|
||||
emitTypeResolver(resultTypes[i], op.getResultName(i));
|
||||
body << ");\n";
|
||||
}
|
||||
}
|
||||
|
@ -552,25 +599,19 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
if (hasAllOperands) {
|
||||
body << " if (parser.resolveOperands(allOperands, ";
|
||||
|
||||
auto emitOperandType = [&](int idx) {
|
||||
if (Optional<int> val = operandTypes[idx].getBuilderIdx())
|
||||
body << "ArrayRef<Type>(odsBuildableType" << *val << ")";
|
||||
else if (Optional<StringRef> var = operandTypes[idx].getVariable())
|
||||
body << *var << "Types";
|
||||
else
|
||||
body << op.getOperand(idx).name << "Types";
|
||||
};
|
||||
|
||||
// Group all of the operand types together to perform the resolution all at
|
||||
// once. Use llvm::concat to perform the merge. llvm::concat does not allow
|
||||
// the case of a single range, so guard it here.
|
||||
if (op.getNumOperands() > 1) {
|
||||
body << "llvm::concat<const Type>(";
|
||||
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body,
|
||||
emitOperandType);
|
||||
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
|
||||
body << "ArrayRef<Type>(";
|
||||
emitTypeResolver(operandTypes[i], op.getOperand(i).name);
|
||||
body << ")";
|
||||
});
|
||||
body << ")";
|
||||
} else {
|
||||
emitOperandType(/*idx=*/0);
|
||||
emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
|
||||
}
|
||||
|
||||
body << ", allOperandLoc, result.operands))\n"
|
||||
|
@ -583,13 +624,12 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
NamedTypeConstraint &operand = op.getOperand(i);
|
||||
body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
|
||||
if (Optional<int> val = operandTypes[i].getBuilderIdx())
|
||||
body << "odsBuildableType" << *val << ", ";
|
||||
else if (Optional<StringRef> var = operandTypes[i].getVariable())
|
||||
body << *var << "Types, " << operand.name << "OperandsLoc, ";
|
||||
else
|
||||
body << operand.name << "Types, " << operand.name << "OperandsLoc, ";
|
||||
body << "result.operands))\n return failure();\n";
|
||||
emitTypeResolver(operandTypes[i], operand.name);
|
||||
|
||||
// If this isn't a buildable type, verify the sizes match by adding the loc.
|
||||
if (!operandTypes[i].getBuilderIdx())
|
||||
body << ", " << operand.name << "OperandsLoc";
|
||||
body << ", result.operands))\n return failure();\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -954,18 +994,30 @@ public:
|
|||
LogicalResult parse();
|
||||
|
||||
private:
|
||||
/// This struct represents a type resolution instance. It includes a specific
|
||||
/// type as well as an optional transformer to apply to that type in order to
|
||||
/// properly resolve the type of a variable.
|
||||
struct TypeResolutionInstance {
|
||||
const NamedTypeConstraint *type;
|
||||
Optional<StringRef> transformer;
|
||||
};
|
||||
|
||||
/// Given the values of an `AllTypesMatch` trait, check for inferrable type
|
||||
/// resolution.
|
||||
void handleAllTypesMatchConstraint(
|
||||
ArrayRef<StringRef> values,
|
||||
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver);
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||
/// Check for inferrable type resolution given all operands, and or results,
|
||||
/// have the same type. If 'includeResults' is true, the results also have the
|
||||
/// same type as all of the operands.
|
||||
void handleSameTypesConstraint(
|
||||
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
bool includeResults);
|
||||
|
||||
/// Returns an argument with the given name that has been seen within the
|
||||
/// format.
|
||||
const NamedTypeConstraint *findSeenArg(StringRef name);
|
||||
|
||||
/// Parse a specific element.
|
||||
LogicalResult parseElement(std::unique_ptr<Element> &element,
|
||||
bool isTopLevel);
|
||||
|
@ -1044,16 +1096,21 @@ LogicalResult FormatParser::parse() {
|
|||
return emitError(loc, "format missing 'attr-dict' directive");
|
||||
|
||||
// Check for any type traits that we can use for inferring types.
|
||||
llvm::StringMap<const NamedTypeConstraint *> variableTyResolver;
|
||||
llvm::StringMap<TypeResolutionInstance> variableTyResolver;
|
||||
for (const OpTrait &trait : op.getTraits()) {
|
||||
const llvm::Record &def = trait.getDef();
|
||||
if (def.isSubClassOf("AllTypesMatch"))
|
||||
if (def.isSubClassOf("AllTypesMatch")) {
|
||||
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
|
||||
variableTyResolver);
|
||||
else if (def.getName() == "SameTypeOperands")
|
||||
} else if (def.getName() == "SameTypeOperands") {
|
||||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
|
||||
else if (def.getName() == "SameOperandsAndResultType")
|
||||
} else if (def.getName() == "SameOperandsAndResultType") {
|
||||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
|
||||
} else if (def.isSubClassOf("TypesMatchWith")) {
|
||||
if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs")))
|
||||
variableTyResolver[def.getValueAsString("rhs")] = {
|
||||
lhsArg, def.getValueAsString("transformer")};
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all of the result types can be inferred.
|
||||
|
@ -1066,7 +1123,8 @@ LogicalResult FormatParser::parse() {
|
|||
// Check to see if we can infer this type from another variable.
|
||||
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
|
||||
if (varResolverIt != variableTyResolver.end()) {
|
||||
fmt.resultTypes[i].setVariable(varResolverIt->second->name);
|
||||
fmt.resultTypes[i].setVariable(varResolverIt->second.type,
|
||||
varResolverIt->second.transformer);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1102,7 +1160,8 @@ LogicalResult FormatParser::parse() {
|
|||
// Check to see if we can infer this type from another variable.
|
||||
auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
|
||||
if (varResolverIt != variableTyResolver.end()) {
|
||||
fmt.operandTypes[i].setVariable(varResolverIt->second->name);
|
||||
fmt.operandTypes[i].setVariable(varResolverIt->second.type,
|
||||
varResolverIt->second.transformer);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1121,30 +1180,23 @@ LogicalResult FormatParser::parse() {
|
|||
|
||||
void FormatParser::handleAllTypesMatchConstraint(
|
||||
ArrayRef<StringRef> values,
|
||||
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver) {
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
for (unsigned i = 0, e = values.size(); i != e; ++i) {
|
||||
// Check to see if this value matches a resolved operand or result type.
|
||||
const NamedTypeConstraint *arg = nullptr;
|
||||
if ((arg = findArg(op.getOperands(), values[i]))) {
|
||||
if (!seenOperandTypes.test(arg - op.operand_begin()))
|
||||
continue;
|
||||
} else if ((arg = findArg(op.getResults(), values[i]))) {
|
||||
if (!seenResultTypes.test(arg - op.result_begin()))
|
||||
continue;
|
||||
} else {
|
||||
const NamedTypeConstraint *arg = findSeenArg(values[i]);
|
||||
if (!arg)
|
||||
continue;
|
||||
}
|
||||
|
||||
// Mark this value as the type resolver for the other variables.
|
||||
for (unsigned j = 0; j != i; ++j)
|
||||
variableTyResolver[values[j]] = arg;
|
||||
variableTyResolver[values[j]] = {arg, llvm::None};
|
||||
for (unsigned j = i + 1; j != e; ++j)
|
||||
variableTyResolver[values[j]] = arg;
|
||||
variableTyResolver[values[j]] = {arg, llvm::None};
|
||||
}
|
||||
}
|
||||
|
||||
void FormatParser::handleSameTypesConstraint(
|
||||
llvm::StringMap<const NamedTypeConstraint *> &variableTyResolver,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
bool includeResults) {
|
||||
const NamedTypeConstraint *resolver = nullptr;
|
||||
int resolvedIt = -1;
|
||||
|
@ -1160,14 +1212,22 @@ void FormatParser::handleSameTypesConstraint(
|
|||
// Set the resolvers for each operand and result.
|
||||
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
|
||||
if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
|
||||
variableTyResolver[op.getOperand(i).name] = resolver;
|
||||
variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
|
||||
if (includeResults) {
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
|
||||
if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
|
||||
variableTyResolver[op.getResultName(i)] = resolver;
|
||||
variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
|
||||
}
|
||||
}
|
||||
|
||||
const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) {
|
||||
if (auto *arg = findArg(op.getOperands(), name))
|
||||
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
|
||||
if (auto *arg = findArg(op.getResults(), name))
|
||||
return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
|
||||
bool isTopLevel) {
|
||||
// Directives.
|
||||
|
@ -1191,7 +1251,8 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
|
|||
StringRef name = varTok.getSpelling().drop_front();
|
||||
llvm::SMLoc loc = varTok.getLoc();
|
||||
|
||||
// Check that the parsed argument is something actually registered on the op.
|
||||
// Check that the parsed argument is something actually registered on the
|
||||
// op.
|
||||
/// Attributes
|
||||
if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
|
||||
if (isTopLevel && !seenAttrs.insert(attr).second)
|
||||
|
|
Loading…
Reference in New Issue