forked from OSchip/llvm-project
[mlir] Update various operations to declaratively specify their assembly format.
Summary: This revision switches over many operations to use the declarative methods for defining the assembly specification. This updates operations in the NVVM, ROCDL, Standard, and VectorOps dialects. Differential Revision: https://reviews.llvm.org/D73407
This commit is contained in:
parent
1c158d0f90
commit
82170d5619
mlir
include/mlir
lib/Dialect
test/Dialect/LLVMIR
|
@ -42,8 +42,7 @@ class NVVM_SpecialRegisterOp<string mnemonic,
|
|||
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
|
||||
string llvmBuilder = "$res = createIntrinsicCall(builder,"
|
||||
# "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");";
|
||||
let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }];
|
||||
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
|
||||
let assemblyFormat = "attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -77,8 +76,7 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
|
|||
string llvmBuilder = [{
|
||||
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
|
||||
}];
|
||||
let parser = [{ return success(); }];
|
||||
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVVM_ShflBflyOp :
|
||||
|
@ -129,8 +127,7 @@ def NVVM_MmaOp :
|
|||
$res = createIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args);
|
||||
}];
|
||||
let parser = [{ return parseNVVMMmaOp(parser, result); }];
|
||||
let printer = [{ printNVVMMmaOp(p, *this); }];
|
||||
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
|
|
|
@ -42,8 +42,7 @@ class ROCDL_SpecialRegisterOp<string mnemonic,
|
|||
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
|
||||
string llvmBuilder = "$res = createIntrinsicCall(builder,"
|
||||
# "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic) # ");";
|
||||
let parser = [{ return parseROCDLOp(parser, result); }];
|
||||
let printer = [{ printROCDLOp(p, this->getOperation()); }];
|
||||
let assemblyFormat = "attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
class ROCDL_DeviceFunctionOp<string mnemonic, string device_function,
|
||||
|
@ -52,8 +51,7 @@ class ROCDL_DeviceFunctionOp<string mnemonic, string device_function,
|
|||
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
|
||||
string llvmBuilder = "$res = createDeviceFunctionCall(builder, \""
|
||||
# device_function # "\", " # parameter # ");";
|
||||
let parser = [{ return parseROCDLOp(parser, result); }];
|
||||
let printer = [{ printROCDLOp(p, this->getOperation()); }];
|
||||
let assemblyFormat = "attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -304,6 +304,10 @@ def CallOp : Std_Op<"call", [CallOpInterface]> {
|
|||
return getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
|
||||
|
@ -651,6 +655,7 @@ def DeallocOp : Std_Op<"dealloc"> {
|
|||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
let assemblyFormat = "$memref attr-dict `:` type($memref)";
|
||||
}
|
||||
|
||||
def DimOp : Std_Op<"dim", [NoSideEffect]> {
|
||||
|
@ -987,6 +992,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
|
|||
}]>];
|
||||
|
||||
let hasFolder = 1;
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def RemFOp : FloatArithmeticOp<"remf"> {
|
||||
|
|
|
@ -203,6 +203,7 @@ def Vector_BroadcastOp :
|
|||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
|
||||
}
|
||||
|
||||
def Vector_ShuffleOp :
|
||||
|
@ -364,6 +365,10 @@ def Vector_ExtractSlicesOp :
|
|||
static StringRef getSizesAttrName() { return "sizes"; }
|
||||
static StringRef getStridesAttrName() { return "strides"; }
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$vector `,` $sizes `,` $strides attr-dict `:` type($vector) `into`
|
||||
type(results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_InsertElementOp :
|
||||
|
@ -482,6 +487,10 @@ def Vector_InsertSlicesOp :
|
|||
static StringRef getSizesAttrName() { return "sizes"; }
|
||||
static StringRef getStridesAttrName() { return "strides"; }
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$vectors `,` $sizes `,` $strides attr-dict `:` type($vectors) `into`
|
||||
type(results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_InsertStridedSliceOp :
|
||||
|
@ -727,6 +736,7 @@ def Vector_StridedSliceOp :
|
|||
void getOffsets(SmallVectorImpl<int64_t> &results);
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
|
||||
}
|
||||
|
||||
def Vector_TransferReadOp :
|
||||
|
@ -939,6 +949,10 @@ def Vector_TransferWriteOp :
|
|||
return memref().getType().cast<MemRefType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$vector `,` $memref `[` $indices `]` attr-dict `:` type($vector) `,`
|
||||
type($memref)
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_TypeCastOp :
|
||||
|
@ -1017,6 +1031,7 @@ def Vector_ConstantMaskOp :
|
|||
let extraClassDeclaration = [{
|
||||
static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; }
|
||||
}];
|
||||
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
|
||||
}
|
||||
|
||||
def Vector_CreateMaskOp :
|
||||
|
@ -1048,6 +1063,7 @@ def Vector_CreateMaskOp :
|
|||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let assemblyFormat = "$operands attr-dict `:` type(results)";
|
||||
}
|
||||
|
||||
def Vector_TupleOp :
|
||||
|
@ -1148,6 +1164,7 @@ def Vector_PrintOp :
|
|||
return source().getType();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$source attr-dict `:` type($source)";
|
||||
}
|
||||
|
||||
#endif // VECTOR_OPS
|
||||
|
|
|
@ -312,7 +312,8 @@ class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
|
|||
def AnyInteger : Type<CPred<"$_self.isa<IntegerType>()">, "integer">;
|
||||
|
||||
// Index type.
|
||||
def Index : Type<CPred<"$_self.isa<IndexType>()">, "index">;
|
||||
def Index : Type<CPred<"$_self.isa<IndexType>()">, "index">,
|
||||
BuildableType<"getIndexType()">;
|
||||
|
||||
// Integer type of a specific width.
|
||||
class I<int width>
|
||||
|
|
|
@ -41,18 +41,6 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
|
|||
p << " : " << op->getResultTypes();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.nvvm.XYZ` : type
|
||||
static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
Type type;
|
||||
if (parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
result.addTypes(type);
|
||||
return success();
|
||||
}
|
||||
|
||||
static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
|
||||
return parser.getBuilder()
|
||||
.getContext()
|
||||
|
@ -103,41 +91,6 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
|
|||
parser.getNameLoc(), result.operands));
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.nvvm.mma.sync %lhs... %rhs... %acc...`
|
||||
// : signature_type
|
||||
static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 12> ops;
|
||||
Type type;
|
||||
llvm::SMLoc typeLoc;
|
||||
if (parser.parseOperandList(ops) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
||||
parser.getCurrentLocation(&typeLoc) || parser.parseType(type)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto signature = type.dyn_cast<FunctionType>();
|
||||
if (!signature) {
|
||||
return parser.emitError(
|
||||
typeLoc, "expected the type to be the full list of input and output");
|
||||
}
|
||||
|
||||
if (signature.getNumResults() != 1) {
|
||||
return parser.emitError(typeLoc, "expected single result");
|
||||
}
|
||||
|
||||
return failure(parser.addTypeToList(signature.getResult(0), result.types) ||
|
||||
parser.resolveOperands(ops, signature.getInputs(),
|
||||
parser.getNameLoc(), result.operands));
|
||||
}
|
||||
|
||||
static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) {
|
||||
p << op.getOperationName() << " " << op.getOperands();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : "
|
||||
<< FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()),
|
||||
op.getType(), op.getContext());
|
||||
}
|
||||
|
||||
static LogicalResult verify(MmaOp op) {
|
||||
auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);
|
||||
|
|
|
@ -30,24 +30,6 @@
|
|||
using namespace mlir;
|
||||
using namespace ROCDL;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for ROCDL ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printROCDLOp(OpAsmPrinter &p, Operation *op) {
|
||||
p << op->getName() << " " << op->getOperands();
|
||||
if (op->getNumResults() > 0)
|
||||
p << " : " << op->getResultTypes();
|
||||
}
|
||||
|
||||
// <operation> ::= `rocdl.XYZ` : type
|
||||
static ParseResult parseROCDLOp(OpAsmParser &parser, OperationState &result) {
|
||||
Type type;
|
||||
return failure(parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.addTypeToList(type, result.types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ROCDLDialect initialization, type parsing, and registration.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -446,29 +446,6 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
// CallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
|
||||
FlatSymbolRefAttr calleeAttr;
|
||||
FunctionType calleeType;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||
auto calleeLoc = parser.getNameLoc();
|
||||
if (parser.parseAttribute(calleeAttr, "callee", result.attributes) ||
|
||||
parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(calleeType) ||
|
||||
parser.addTypesToList(calleeType.getResults(), result.types) ||
|
||||
parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
|
||||
result.operands))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, CallOp op) {
|
||||
p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
p << " : " << op.getCalleeType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(CallOp op) {
|
||||
// Check that the callee attribute was specified.
|
||||
auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
|
||||
|
@ -1184,19 +1161,6 @@ struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
|
|||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
static void print(OpAsmPrinter &p, DeallocOp op) {
|
||||
p << "dealloc " << op.memref() << " : " << op.memref().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType memrefInfo;
|
||||
MemRefType type;
|
||||
|
||||
return failure(parser.parseOperand(memrefInfo) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.resolveOperand(memrefInfo, type, result.operands));
|
||||
}
|
||||
|
||||
static LogicalResult verify(DeallocOp op) {
|
||||
if (!op.memref().getType().isa<MemRefType>())
|
||||
return op.emitOpError("operand must be a memref");
|
||||
|
@ -1844,20 +1808,6 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
|
|||
// RankOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, RankOp op) {
|
||||
p << "rank " << op.getOperand() << " : " << op.getOperand().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType operandInfo;
|
||||
Type type;
|
||||
Type indexType = parser.getBuilder().getIndexType();
|
||||
return failure(parser.parseOperand(operandInfo) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.resolveOperand(operandInfo, type, result.operands) ||
|
||||
parser.addTypeToList(indexType, result.types));
|
||||
}
|
||||
|
||||
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Constant fold rank when the rank of the tensor is known.
|
||||
auto type = getOperand().getType();
|
||||
|
|
|
@ -474,38 +474,6 @@ void ExtractSlicesOp::build(Builder *builder, OperationState &result,
|
|||
result.addAttribute(getStridesAttrName(), stridesAttr);
|
||||
}
|
||||
|
||||
static ParseResult parseExtractSlicesOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType operandInfo;
|
||||
ArrayAttr sizesAttr;
|
||||
StringRef sizesAttrName = ExtractSlicesOp::getSizesAttrName();
|
||||
ArrayAttr stridesAttr;
|
||||
StringRef stridesAttrName = ExtractSlicesOp::getStridesAttrName();
|
||||
VectorType vectorType;
|
||||
TupleType resultTupleType;
|
||||
return failure(
|
||||
parser.parseOperand(operandInfo) || parser.parseComma() ||
|
||||
parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
|
||||
parser.parseComma() ||
|
||||
parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(vectorType) ||
|
||||
parser.parseKeywordType("into", resultTupleType) ||
|
||||
parser.resolveOperand(operandInfo, vectorType, result.operands) ||
|
||||
parser.addTypeToList(resultTupleType, result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, ExtractSlicesOp op) {
|
||||
p << op.getOperationName() << ' ' << op.vector() << ", ";
|
||||
p << op.sizes() << ", " << op.strides();
|
||||
p.printOptionalAttrDict(
|
||||
op.getAttrs(),
|
||||
/*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(),
|
||||
ExtractSlicesOp::getStridesAttrName()});
|
||||
p << " : " << op.vector().getType();
|
||||
p << " into " << op.getResultTupleType();
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
|
||||
TupleType tupleType, ArrayRef<int64_t> sizes,
|
||||
|
@ -572,11 +540,6 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
|
|||
// BroadcastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, BroadcastOp op) {
|
||||
p << op.getOperationName() << " " << op.source() << " : "
|
||||
<< op.getSourceType() << " to " << op.getVectorType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(BroadcastOp op) {
|
||||
VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
|
||||
VectorType dstVectorType = op.getVectorType();
|
||||
|
@ -601,18 +564,6 @@ static LogicalResult verify(BroadcastOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseBroadcastOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType source;
|
||||
Type sourceType;
|
||||
VectorType vectorType;
|
||||
return failure(parser.parseOperand(source) ||
|
||||
parser.parseColonType(sourceType) ||
|
||||
parser.parseKeywordType("to", vectorType) ||
|
||||
parser.resolveOperand(source, sourceType, result.operands) ||
|
||||
parser.addTypeToList(vectorType, result.types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShuffleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -808,38 +759,6 @@ static LogicalResult verify(InsertOp op) {
|
|||
// InsertSlicesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseInsertSlicesOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType operandInfo;
|
||||
ArrayAttr sizesAttr;
|
||||
StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName();
|
||||
ArrayAttr stridesAttr;
|
||||
StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName();
|
||||
TupleType tupleType;
|
||||
VectorType resultVectorType;
|
||||
return failure(
|
||||
parser.parseOperand(operandInfo) || parser.parseComma() ||
|
||||
parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
|
||||
parser.parseComma() ||
|
||||
parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(tupleType) ||
|
||||
parser.parseKeywordType("into", resultVectorType) ||
|
||||
parser.resolveOperand(operandInfo, tupleType, result.operands) ||
|
||||
parser.addTypeToList(resultVectorType, result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertSlicesOp op) {
|
||||
p << op.getOperationName() << ' ' << op.vectors() << ", ";
|
||||
p << op.sizes() << ", " << op.strides();
|
||||
p.printOptionalAttrDict(
|
||||
op.getAttrs(),
|
||||
/*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(),
|
||||
InsertSlicesOp::getStridesAttrName()});
|
||||
p << " : " << op.vectors().getType();
|
||||
p << " into " << op.getResultVectorType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(InsertSlicesOp op) {
|
||||
SmallVector<int64_t, 4> sizes;
|
||||
op.getSizes(sizes);
|
||||
|
@ -1231,27 +1150,6 @@ void StridedSliceOp::build(Builder *builder, OperationState &result,
|
|||
result.addAttribute(getStridesAttrName(), stridesAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, StridedSliceOp op) {
|
||||
p << op.getOperationName() << " " << op.vector();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.vector().getType() << " to " << op.getResult().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseStridedSliceOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
llvm::SMLoc attributeLoc, typeLoc;
|
||||
OpAsmParser::OperandType vector;
|
||||
VectorType vectorType, resultVectorType;
|
||||
return failure(parser.parseOperand(vector) ||
|
||||
parser.getCurrentLocation(&attributeLoc) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.getCurrentLocation(&typeLoc) ||
|
||||
parser.parseColonType(vectorType) ||
|
||||
parser.parseKeywordType("to", resultVectorType) ||
|
||||
parser.resolveOperand(vector, vectorType, result.operands) ||
|
||||
parser.addTypeToList(resultVectorType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(StridedSliceOp op) {
|
||||
auto type = op.getVectorType();
|
||||
auto offsets = op.offsets();
|
||||
|
@ -1519,35 +1417,6 @@ static LogicalResult verify(TransferReadOp op) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
// TransferWriteOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
static void print(OpAsmPrinter &p, TransferWriteOp op) {
|
||||
p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
|
||||
<< op.indices() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getVectorType() << ", " << op.getMemRefType();
|
||||
}
|
||||
|
||||
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
llvm::SMLoc typesLoc;
|
||||
OpAsmParser::OperandType storeValueInfo;
|
||||
OpAsmParser::OperandType memRefInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||
SmallVector<Type, 2> types;
|
||||
if (parser.parseOperand(storeValueInfo) || parser.parseComma() ||
|
||||
parser.parseOperand(memRefInfo) ||
|
||||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
|
||||
return failure();
|
||||
if (types.size() != 2)
|
||||
return parser.emitError(typesLoc, "two types required");
|
||||
auto indexType = parser.getBuilder().getIndexType();
|
||||
Type vectorType = types[0], memRefType = types[1];
|
||||
return failure(
|
||||
parser.resolveOperand(storeValueInfo, vectorType, result.operands) ||
|
||||
parser.resolveOperand(memRefInfo, memRefType, result.operands) ||
|
||||
parser.resolveOperands(indexInfo, indexType, result.operands));
|
||||
}
|
||||
|
||||
static LogicalResult verify(TransferWriteOp op) {
|
||||
// Consistency of elemental types in memref and vector.
|
||||
|
@ -1676,23 +1545,6 @@ OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
|
|||
// ConstantMaskOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseConstantMaskOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
Type resultType;
|
||||
ArrayAttr maskDimSizesAttr;
|
||||
StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName();
|
||||
return failure(
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseAttribute(maskDimSizesAttr, attrName, result.attributes) ||
|
||||
parser.parseColonType(resultType) ||
|
||||
parser.addTypeToList(resultType, result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, ConstantMaskOp op) {
|
||||
p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
|
||||
<< op.getResult().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(ConstantMaskOp &op) {
|
||||
// Verify that array attr size matches the rank of the vector result.
|
||||
auto resultType = op.getResult().getType().cast<VectorType>();
|
||||
|
@ -1724,23 +1576,6 @@ static LogicalResult verify(ConstantMaskOp &op) {
|
|||
// CreateMaskOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseCreateMaskOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
auto indexType = parser.getBuilder().getIndexType();
|
||||
Type resultType;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
|
||||
return failure(
|
||||
parser.parseOperandList(operandInfo) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(resultType) ||
|
||||
parser.resolveOperands(operandInfo, indexType, result.operands) ||
|
||||
parser.addTypeToList(resultType, result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, CreateMaskOp op) {
|
||||
p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(CreateMaskOp op) {
|
||||
// Verify that an operand was specified for each result vector each dimension.
|
||||
if (op.getNumOperands() !=
|
||||
|
@ -1750,23 +1585,6 @@ static LogicalResult verify(CreateMaskOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrintOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType source;
|
||||
Type sourceType;
|
||||
return failure(parser.parseOperand(source) ||
|
||||
parser.parseColonType(sourceType) ||
|
||||
parser.resolveOperand(source, sourceType, result.operands));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, PrintOp op) {
|
||||
p << op.getOperationName() << ' ' << op.source() << " : "
|
||||
<< op.getPrintType();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
|
||||
|
|
|
@ -366,7 +366,7 @@ func @nvvm_invalid_mma_6(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
|
|||
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
|
||||
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
||||
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
||||
// expected-error@+1 {{expected the type to be the full list of input and output}}
|
||||
// expected-error@+1 {{invalid kind of type specified}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
}
|
||||
|
@ -378,9 +378,9 @@ func @nvvm_invalid_mma_7(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
|
|||
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
|
||||
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
||||
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
||||
// expected-error@+1 {{expected single result}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
|
||||
llvm.return %0 : (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
|
||||
// expected-error@+1 {{op requires one result}}
|
||||
%0:2 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
|
||||
llvm.return %0#0 : !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue