From 82170d5619987aac0de1f7cc62bdcdc8a68e783c Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 30 Jan 2020 11:31:50 -0800 Subject: [PATCH] [mlir] Update various operations to declaratively specify their assembly format. Summary: This revision switches over many operations to use the declarative methods for defining the assembly specification. This updates operations in the NVVM, ROCDL, Standard, and VectorOps dialects. Differential Revision: https://reviews.llvm.org/D73407 --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 9 +- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 6 +- mlir/include/mlir/Dialect/StandardOps/Ops.td | 6 + .../mlir/Dialect/VectorOps/VectorOps.td | 17 ++ mlir/include/mlir/IR/OpBase.td | 3 +- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 47 ----- mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 18 -- mlir/lib/Dialect/StandardOps/Ops.cpp | 50 ----- mlir/lib/Dialect/VectorOps/VectorOps.cpp | 182 ------------------ mlir/test/Dialect/LLVMIR/invalid.mlir | 8 +- 10 files changed, 34 insertions(+), 312 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index cc00ade4109a..0b18ef75897f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -42,8 +42,7 @@ class NVVM_SpecialRegisterOp, Arguments<(ins)> { string llvmBuilder = "$res = createIntrinsicCall(builder," # "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");"; - let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }]; - let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; + let assemblyFormat = "attr-dict `:` type($res)"; } //===----------------------------------------------------------------------===// @@ -77,8 +76,7 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> { string llvmBuilder = [{ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0); }]; - let parser = [{ return success(); }]; - let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; + let assemblyFormat = "attr-dict"; } def NVVM_ShflBflyOp : @@ -129,8 +127,7 @@ def NVVM_MmaOp : $res = createIntrinsicCall( builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args); }]; - let parser = [{ return parseNVVMMmaOp(parser, result); }]; - let printer = [{ printNVVMMmaOp(p, *this); }]; + let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)"; let verifier = [{ return ::verify(*this); }]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index d16281151ac0..c13b02783635 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -42,8 +42,7 @@ class ROCDL_SpecialRegisterOp, Arguments<(ins)> { string llvmBuilder = "$res = createIntrinsicCall(builder," # "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic) # ");"; - let parser = [{ return parseROCDLOp(parser, result); }]; - let printer = [{ printROCDLOp(p, this->getOperation()); }]; + let assemblyFormat = "attr-dict `:` type($res)"; } class ROCDL_DeviceFunctionOp, Arguments<(ins)> { string llvmBuilder = "$res = createDeviceFunctionCall(builder, \"" # device_function # "\", " # parameter # ");"; - let parser = [{ return parseROCDLOp(parser, result); }]; - let printer = [{ printROCDLOp(p, this->getOperation()); }]; + let assemblyFormat = "attr-dict `:` type($res)"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index 6b49eed92451..a7647edd9909 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -304,6 +304,10 @@ def CallOp : Std_Op<"call", [CallOpInterface]> { return getAttrOfType("callee"); } }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; } def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> { @@ -651,6 +655,7 @@ def DeallocOp : Std_Op<"dealloc"> { let hasCanonicalizer = 1; let hasFolder = 1; + let assemblyFormat = "$memref attr-dict `:` type($memref)"; } def DimOp : Std_Op<"dim", [NoSideEffect]> { @@ -987,6 +992,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> { }]>]; let hasFolder = 1; + let assemblyFormat = "operands attr-dict `:` type(operands)"; } def RemFOp : FloatArithmeticOp<"remf"> { diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 847069ce4ea2..655158a623a6 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -203,6 +203,7 @@ def Vector_BroadcastOp : return vector().getType().cast(); } }]; + let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)"; } def Vector_ShuffleOp : @@ -364,6 +365,10 @@ def Vector_ExtractSlicesOp : static StringRef getSizesAttrName() { return "sizes"; } static StringRef getStridesAttrName() { return "strides"; } }]; + let assemblyFormat = [{ + $vector `,` $sizes `,` $strides attr-dict `:` type($vector) `into` + type(results) + }]; } def Vector_InsertElementOp : @@ -482,6 +487,10 @@ def Vector_InsertSlicesOp : static StringRef getSizesAttrName() { return "sizes"; } static StringRef getStridesAttrName() { return "strides"; } }]; + let assemblyFormat = [{ + $vectors `,` $sizes `,` $strides attr-dict `:` type($vectors) `into` + type(results) + }]; } def Vector_InsertStridedSliceOp : @@ -727,6 +736,7 @@ def Vector_StridedSliceOp : void getOffsets(SmallVectorImpl &results); }]; let hasCanonicalizer = 1; + let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)"; } def Vector_TransferReadOp : @@ -939,6 +949,10 @@ def Vector_TransferWriteOp : return memref().getType().cast(); } }]; + let assemblyFormat = [{ + $vector `,` $memref `[` $indices `]` attr-dict `:` type($vector) `,` + type($memref) + }]; } def Vector_TypeCastOp : @@ -1017,6 +1031,7 @@ def Vector_ConstantMaskOp : let extraClassDeclaration = [{ static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; } }]; + let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)"; } def Vector_CreateMaskOp : @@ -1048,6 +1063,7 @@ def Vector_CreateMaskOp : }]; let hasCanonicalizer = 1; + let assemblyFormat = "$operands attr-dict `:` type(results)"; } def Vector_TupleOp : @@ -1148,6 +1164,7 @@ def Vector_PrintOp : return source().getType(); } }]; + let assemblyFormat = "$source attr-dict `:` type($source)"; } #endif // VECTOR_OPS diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 8dd87c6630e2..fd75a2ce66c1 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -312,7 +312,8 @@ class AnyTypeOf allowedTypes, string description = ""> : Type< def AnyInteger : Type()">, "integer">; // Index type. -def Index : Type()">, "index">; +def Index : Type()">, "index">, + BuildableType<"getIndexType()">; // Integer type of a specific width. class I diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 9900875392c0..b7b32df12ca6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -41,18 +41,6 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) { p << " : " << op->getResultTypes(); } -// ::= `llvm.nvvm.XYZ` : type -static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser &parser, - OperationState &result) { - Type type; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type)) - return failure(); - - result.addTypes(type); - return success(); -} - static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) { return parser.getBuilder() .getContext() @@ -103,41 +91,6 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser, parser.getNameLoc(), result.operands)); } -// ::= `llvm.nvvm.mma.sync %lhs... %rhs... %acc...` -// : signature_type -static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) { - SmallVector ops; - Type type; - llvm::SMLoc typeLoc; - if (parser.parseOperandList(ops) || - parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || - parser.getCurrentLocation(&typeLoc) || parser.parseType(type)) { - return failure(); - } - - auto signature = type.dyn_cast(); - if (!signature) { - return parser.emitError( - typeLoc, "expected the type to be the full list of input and output"); - } - - if (signature.getNumResults() != 1) { - return parser.emitError(typeLoc, "expected single result"); - } - - return failure(parser.addTypeToList(signature.getResult(0), result.types) || - parser.resolveOperands(ops, signature.getInputs(), - parser.getNameLoc(), result.operands)); -} - -static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) { - p << op.getOperationName() << " " << op.getOperands(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " - << FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()), - op.getType(), op.getContext()); -} - static LogicalResult verify(MmaOp op) { auto dialect = op.getContext()->getRegisteredDialect(); auto f16Ty = LLVM::LLVMType::getHalfTy(dialect); diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 4d6005dc8fad..28ee28f71e86 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -30,24 +30,6 @@ using namespace mlir; using namespace ROCDL; -//===----------------------------------------------------------------------===// -// Printing/parsing for ROCDL ops -//===----------------------------------------------------------------------===// - -static void printROCDLOp(OpAsmPrinter &p, Operation *op) { - p << op->getName() << " " << op->getOperands(); - if (op->getNumResults() > 0) - p << " : " << op->getResultTypes(); -} - -// ::= `rocdl.XYZ` : type -static ParseResult parseROCDLOp(OpAsmParser &parser, OperationState &result) { - Type type; - return failure(parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.addTypeToList(type, result.types)); -} - //===----------------------------------------------------------------------===// // ROCDLDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 824a2ea87d96..fded6082273c 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -446,29 +446,6 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // CallOp //===----------------------------------------------------------------------===// -static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { - FlatSymbolRefAttr calleeAttr; - FunctionType calleeType; - SmallVector operands; - auto calleeLoc = parser.getNameLoc(); - if (parser.parseAttribute(calleeAttr, "callee", result.attributes) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(calleeType) || - parser.addTypesToList(calleeType.getResults(), result.types) || - parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, - result.operands)) - return failure(); - - return success(); -} - -static void print(OpAsmPrinter &p, CallOp op) { - p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : " << op.getCalleeType(); -} - static LogicalResult verify(CallOp op) { // Check that the callee attribute was specified. auto fnAttr = op.getAttrOfType("callee"); @@ -1184,19 +1161,6 @@ struct SimplifyDeadDealloc : public OpRewritePattern { }; } // end anonymous namespace. -static void print(OpAsmPrinter &p, DeallocOp op) { - p << "dealloc " << op.memref() << " : " << op.memref().getType(); -} - -static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType memrefInfo; - MemRefType type; - - return failure(parser.parseOperand(memrefInfo) || - parser.parseColonType(type) || - parser.resolveOperand(memrefInfo, type, result.operands)); -} - static LogicalResult verify(DeallocOp op) { if (!op.memref().getType().isa()) return op.emitOpError("operand must be a memref"); @@ -1844,20 +1808,6 @@ LogicalResult PrefetchOp::fold(ArrayRef cstOperands, // RankOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, RankOp op) { - p << "rank " << op.getOperand() << " : " << op.getOperand().getType(); -} - -static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType operandInfo; - Type type; - Type indexType = parser.getBuilder().getIndexType(); - return failure(parser.parseOperand(operandInfo) || - parser.parseColonType(type) || - parser.resolveOperand(operandInfo, type, result.operands) || - parser.addTypeToList(indexType, result.types)); -} - OpFoldResult RankOp::fold(ArrayRef operands) { // Constant fold rank when the rank of the tensor is known. auto type = getOperand().getType(); diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 40092e250e04..94cce5e12662 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -474,38 +474,6 @@ void ExtractSlicesOp::build(Builder *builder, OperationState &result, result.addAttribute(getStridesAttrName(), stridesAttr); } -static ParseResult parseExtractSlicesOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType operandInfo; - ArrayAttr sizesAttr; - StringRef sizesAttrName = ExtractSlicesOp::getSizesAttrName(); - ArrayAttr stridesAttr; - StringRef stridesAttrName = ExtractSlicesOp::getStridesAttrName(); - VectorType vectorType; - TupleType resultTupleType; - return failure( - parser.parseOperand(operandInfo) || parser.parseComma() || - parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) || - parser.parseComma() || - parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(vectorType) || - parser.parseKeywordType("into", resultTupleType) || - parser.resolveOperand(operandInfo, vectorType, result.operands) || - parser.addTypeToList(resultTupleType, result.types)); -} - -static void print(OpAsmPrinter &p, ExtractSlicesOp op) { - p << op.getOperationName() << ' ' << op.vector() << ", "; - p << op.sizes() << ", " << op.strides(); - p.printOptionalAttrDict( - op.getAttrs(), - /*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(), - ExtractSlicesOp::getStridesAttrName()}); - p << " : " << op.vector().getType(); - p << " into " << op.getResultTupleType(); -} - static LogicalResult isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType, TupleType tupleType, ArrayRef sizes, @@ -572,11 +540,6 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl &results) { // BroadcastOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, BroadcastOp op) { - p << op.getOperationName() << " " << op.source() << " : " - << op.getSourceType() << " to " << op.getVectorType(); -} - static LogicalResult verify(BroadcastOp op) { VectorType srcVectorType = op.getSourceType().dyn_cast(); VectorType dstVectorType = op.getVectorType(); @@ -601,18 +564,6 @@ static LogicalResult verify(BroadcastOp op) { return success(); } -static ParseResult parseBroadcastOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType source; - Type sourceType; - VectorType vectorType; - return failure(parser.parseOperand(source) || - parser.parseColonType(sourceType) || - parser.parseKeywordType("to", vectorType) || - parser.resolveOperand(source, sourceType, result.operands) || - parser.addTypeToList(vectorType, result.types)); -} - //===----------------------------------------------------------------------===// // ShuffleOp //===----------------------------------------------------------------------===// @@ -808,38 +759,6 @@ static LogicalResult verify(InsertOp op) { // InsertSlicesOp //===----------------------------------------------------------------------===// -static ParseResult parseInsertSlicesOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType operandInfo; - ArrayAttr sizesAttr; - StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName(); - ArrayAttr stridesAttr; - StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName(); - TupleType tupleType; - VectorType resultVectorType; - return failure( - parser.parseOperand(operandInfo) || parser.parseComma() || - parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) || - parser.parseComma() || - parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(tupleType) || - parser.parseKeywordType("into", resultVectorType) || - parser.resolveOperand(operandInfo, tupleType, result.operands) || - parser.addTypeToList(resultVectorType, result.types)); -} - -static void print(OpAsmPrinter &p, InsertSlicesOp op) { - p << op.getOperationName() << ' ' << op.vectors() << ", "; - p << op.sizes() << ", " << op.strides(); - p.printOptionalAttrDict( - op.getAttrs(), - /*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(), - InsertSlicesOp::getStridesAttrName()}); - p << " : " << op.vectors().getType(); - p << " into " << op.getResultVectorType(); -} - static LogicalResult verify(InsertSlicesOp op) { SmallVector sizes; op.getSizes(sizes); @@ -1231,27 +1150,6 @@ void StridedSliceOp::build(Builder *builder, OperationState &result, result.addAttribute(getStridesAttrName(), stridesAttr); } -static void print(OpAsmPrinter &p, StridedSliceOp op) { - p << op.getOperationName() << " " << op.vector(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.vector().getType() << " to " << op.getResult().getType(); -} - -static ParseResult parseStridedSliceOp(OpAsmParser &parser, - OperationState &result) { - llvm::SMLoc attributeLoc, typeLoc; - OpAsmParser::OperandType vector; - VectorType vectorType, resultVectorType; - return failure(parser.parseOperand(vector) || - parser.getCurrentLocation(&attributeLoc) || - parser.parseOptionalAttrDict(result.attributes) || - parser.getCurrentLocation(&typeLoc) || - parser.parseColonType(vectorType) || - parser.parseKeywordType("to", resultVectorType) || - parser.resolveOperand(vector, vectorType, result.operands) || - parser.addTypeToList(resultVectorType, result.types)); -} - static LogicalResult verify(StridedSliceOp op) { auto type = op.getVectorType(); auto offsets = op.offsets(); @@ -1519,35 +1417,6 @@ static LogicalResult verify(TransferReadOp op) { //===----------------------------------------------------------------------===// // TransferWriteOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, TransferWriteOp op) { - p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "[" - << op.indices() << "]"; - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getVectorType() << ", " << op.getMemRefType(); -} - -static ParseResult parseTransferWriteOp(OpAsmParser &parser, - OperationState &result) { - llvm::SMLoc typesLoc; - OpAsmParser::OperandType storeValueInfo; - OpAsmParser::OperandType memRefInfo; - SmallVector indexInfo; - SmallVector types; - if (parser.parseOperand(storeValueInfo) || parser.parseComma() || - parser.parseOperand(memRefInfo) || - parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser.parseOptionalAttrDict(result.attributes) || - parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) - return failure(); - if (types.size() != 2) - return parser.emitError(typesLoc, "two types required"); - auto indexType = parser.getBuilder().getIndexType(); - Type vectorType = types[0], memRefType = types[1]; - return failure( - parser.resolveOperand(storeValueInfo, vectorType, result.operands) || - parser.resolveOperand(memRefInfo, memRefType, result.operands) || - parser.resolveOperands(indexInfo, indexType, result.operands)); -} static LogicalResult verify(TransferWriteOp op) { // Consistency of elemental types in memref and vector. @@ -1676,23 +1545,6 @@ OpFoldResult TupleGetOp::fold(ArrayRef operands) { // ConstantMaskOp //===----------------------------------------------------------------------===// -static ParseResult parseConstantMaskOp(OpAsmParser &parser, - OperationState &result) { - Type resultType; - ArrayAttr maskDimSizesAttr; - StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName(); - return failure( - parser.parseOptionalAttrDict(result.attributes) || - parser.parseAttribute(maskDimSizesAttr, attrName, result.attributes) || - parser.parseColonType(resultType) || - parser.addTypeToList(resultType, result.types)); -} - -static void print(OpAsmPrinter &p, ConstantMaskOp op) { - p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : " - << op.getResult().getType(); -} - static LogicalResult verify(ConstantMaskOp &op) { // Verify that array attr size matches the rank of the vector result. auto resultType = op.getResult().getType().cast(); @@ -1724,23 +1576,6 @@ static LogicalResult verify(ConstantMaskOp &op) { // CreateMaskOp //===----------------------------------------------------------------------===// -static ParseResult parseCreateMaskOp(OpAsmParser &parser, - OperationState &result) { - auto indexType = parser.getBuilder().getIndexType(); - Type resultType; - SmallVector operandInfo; - return failure( - parser.parseOperandList(operandInfo) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(resultType) || - parser.resolveOperands(operandInfo, indexType, result.operands) || - parser.addTypeToList(resultType, result.types)); -} - -static void print(OpAsmPrinter &p, CreateMaskOp op) { - p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType(); -} - static LogicalResult verify(CreateMaskOp op) { // Verify that an operand was specified for each result vector each dimension. if (op.getNumOperands() != @@ -1750,23 +1585,6 @@ static LogicalResult verify(CreateMaskOp op) { return success(); } -//===----------------------------------------------------------------------===// -// PrintOp -//===----------------------------------------------------------------------===// - -static ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType source; - Type sourceType; - return failure(parser.parseOperand(source) || - parser.parseColonType(sourceType) || - parser.resolveOperand(source, sourceType, result.operands)); -} - -static void print(OpAsmPrinter &p, PrintOp op) { - p << op.getOperationName() << ' ' << op.source() << " : " - << op.getPrintType(); -} - namespace { // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 153248d8bb82..e0efa92a93ca 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -366,7 +366,7 @@ func @nvvm_invalid_mma_6(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">, %b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">, %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float, %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) { - // expected-error@+1 {{expected the type to be the full list of input and output}} + // expected-error@+1 {{invalid kind of type specified}} %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm<"{ float, float, float, float, float, float, float, float }"> llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }"> } @@ -378,9 +378,9 @@ func @nvvm_invalid_mma_7(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">, %b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">, %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float, %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) { - // expected-error@+1 {{expected single result}} - %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32) - llvm.return %0 : (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32) + // expected-error@+1 {{op requires one result}} + %0:2 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32) + llvm.return %0#0 : !llvm<"{ float, float, float, float, float, float, float, float }"> } // -----