forked from OSchip/llvm-project
[mlir] Update various operations to declaratively specify their assembly format.
Summary: This revision switches over many operations to use the declarative methods for defining the assembly specification. This updates operations in the NVVM, ROCDL, Standard, and VectorOps dialects. Differential Revision: https://reviews.llvm.org/D73407
This commit is contained in:
parent
1c158d0f90
commit
82170d5619
|
@ -42,8 +42,7 @@ class NVVM_SpecialRegisterOp<string mnemonic,
|
||||||
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
|
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
|
||||||
string llvmBuilder = "$res = createIntrinsicCall(builder,"
|
string llvmBuilder = "$res = createIntrinsicCall(builder,"
|
||||||
# "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");";
|
# "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");";
|
||||||
let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }];
|
let assemblyFormat = "attr-dict `:` type($res)";
|
||||||
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -77,8 +76,7 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
|
||||||
string llvmBuilder = [{
|
string llvmBuilder = [{
|
||||||
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
|
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
|
||||||
}];
|
}];
|
||||||
let parser = [{ return success(); }];
|
let assemblyFormat = "attr-dict";
|
||||||
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def NVVM_ShflBflyOp :
|
def NVVM_ShflBflyOp :
|
||||||
|
@ -129,8 +127,7 @@ def NVVM_MmaOp :
|
||||||
$res = createIntrinsicCall(
|
$res = createIntrinsicCall(
|
||||||
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args);
|
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args);
|
||||||
}];
|
}];
|
||||||
let parser = [{ return parseNVVMMmaOp(parser, result); }];
|
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
|
||||||
let printer = [{ printNVVMMmaOp(p, *this); }];
|
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let verifier = [{ return ::verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -42,8 +42,7 @@ class ROCDL_SpecialRegisterOp<string mnemonic,
|
||||||
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
|
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
|
||||||
string llvmBuilder = "$res = createIntrinsicCall(builder,"
|
string llvmBuilder = "$res = createIntrinsicCall(builder,"
|
||||||
# "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic) # ");";
|
# "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic) # ");";
|
||||||
let parser = [{ return parseROCDLOp(parser, result); }];
|
let assemblyFormat = "attr-dict `:` type($res)";
|
||||||
let printer = [{ printROCDLOp(p, this->getOperation()); }];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class ROCDL_DeviceFunctionOp<string mnemonic, string device_function,
|
class ROCDL_DeviceFunctionOp<string mnemonic, string device_function,
|
||||||
|
@ -52,8 +51,7 @@ class ROCDL_DeviceFunctionOp<string mnemonic, string device_function,
|
||||||
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
|
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
|
||||||
string llvmBuilder = "$res = createDeviceFunctionCall(builder, \""
|
string llvmBuilder = "$res = createDeviceFunctionCall(builder, \""
|
||||||
# device_function # "\", " # parameter # ");";
|
# device_function # "\", " # parameter # ");";
|
||||||
let parser = [{ return parseROCDLOp(parser, result); }];
|
let assemblyFormat = "attr-dict `:` type($res)";
|
||||||
let printer = [{ printROCDLOp(p, this->getOperation()); }];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -304,6 +304,10 @@ def CallOp : Std_Op<"call", [CallOpInterface]> {
|
||||||
return getAttrOfType<SymbolRefAttr>("callee");
|
return getAttrOfType<SymbolRefAttr>("callee");
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
|
def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
|
||||||
|
@ -651,6 +655,7 @@ def DeallocOp : Std_Op<"dealloc"> {
|
||||||
|
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let assemblyFormat = "$memref attr-dict `:` type($memref)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def DimOp : Std_Op<"dim", [NoSideEffect]> {
|
def DimOp : Std_Op<"dim", [NoSideEffect]> {
|
||||||
|
@ -987,6 +992,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
|
||||||
}]>];
|
}]>];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def RemFOp : FloatArithmeticOp<"remf"> {
|
def RemFOp : FloatArithmeticOp<"remf"> {
|
||||||
|
|
|
@ -203,6 +203,7 @@ def Vector_BroadcastOp :
|
||||||
return vector().getType().cast<VectorType>();
|
return vector().getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Vector_ShuffleOp :
|
def Vector_ShuffleOp :
|
||||||
|
@ -364,6 +365,10 @@ def Vector_ExtractSlicesOp :
|
||||||
static StringRef getSizesAttrName() { return "sizes"; }
|
static StringRef getSizesAttrName() { return "sizes"; }
|
||||||
static StringRef getStridesAttrName() { return "strides"; }
|
static StringRef getStridesAttrName() { return "strides"; }
|
||||||
}];
|
}];
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$vector `,` $sizes `,` $strides attr-dict `:` type($vector) `into`
|
||||||
|
type(results)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Vector_InsertElementOp :
|
def Vector_InsertElementOp :
|
||||||
|
@ -482,6 +487,10 @@ def Vector_InsertSlicesOp :
|
||||||
static StringRef getSizesAttrName() { return "sizes"; }
|
static StringRef getSizesAttrName() { return "sizes"; }
|
||||||
static StringRef getStridesAttrName() { return "strides"; }
|
static StringRef getStridesAttrName() { return "strides"; }
|
||||||
}];
|
}];
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$vectors `,` $sizes `,` $strides attr-dict `:` type($vectors) `into`
|
||||||
|
type(results)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Vector_InsertStridedSliceOp :
|
def Vector_InsertStridedSliceOp :
|
||||||
|
@ -727,6 +736,7 @@ def Vector_StridedSliceOp :
|
||||||
void getOffsets(SmallVectorImpl<int64_t> &results);
|
void getOffsets(SmallVectorImpl<int64_t> &results);
|
||||||
}];
|
}];
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Vector_TransferReadOp :
|
def Vector_TransferReadOp :
|
||||||
|
@ -939,6 +949,10 @@ def Vector_TransferWriteOp :
|
||||||
return memref().getType().cast<MemRefType>();
|
return memref().getType().cast<MemRefType>();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$vector `,` $memref `[` $indices `]` attr-dict `:` type($vector) `,`
|
||||||
|
type($memref)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Vector_TypeCastOp :
|
def Vector_TypeCastOp :
|
||||||
|
@ -1017,6 +1031,7 @@ def Vector_ConstantMaskOp :
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; }
|
static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; }
|
||||||
}];
|
}];
|
||||||
|
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Vector_CreateMaskOp :
|
def Vector_CreateMaskOp :
|
||||||
|
@ -1048,6 +1063,7 @@ def Vector_CreateMaskOp :
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
let assemblyFormat = "$operands attr-dict `:` type(results)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Vector_TupleOp :
|
def Vector_TupleOp :
|
||||||
|
@ -1148,6 +1164,7 @@ def Vector_PrintOp :
|
||||||
return source().getType();
|
return source().getType();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let assemblyFormat = "$source attr-dict `:` type($source)";
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // VECTOR_OPS
|
#endif // VECTOR_OPS
|
||||||
|
|
|
@ -312,7 +312,8 @@ class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
|
||||||
def AnyInteger : Type<CPred<"$_self.isa<IntegerType>()">, "integer">;
|
def AnyInteger : Type<CPred<"$_self.isa<IntegerType>()">, "integer">;
|
||||||
|
|
||||||
// Index type.
|
// Index type.
|
||||||
def Index : Type<CPred<"$_self.isa<IndexType>()">, "index">;
|
def Index : Type<CPred<"$_self.isa<IndexType>()">, "index">,
|
||||||
|
BuildableType<"getIndexType()">;
|
||||||
|
|
||||||
// Integer type of a specific width.
|
// Integer type of a specific width.
|
||||||
class I<int width>
|
class I<int width>
|
||||||
|
|
|
@ -41,18 +41,6 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
|
||||||
p << " : " << op->getResultTypes();
|
p << " : " << op->getResultTypes();
|
||||||
}
|
}
|
||||||
|
|
||||||
// <operation> ::= `llvm.nvvm.XYZ` : type
|
|
||||||
static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
Type type;
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.parseColonType(type))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
result.addTypes(type);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
|
static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
|
||||||
return parser.getBuilder()
|
return parser.getBuilder()
|
||||||
.getContext()
|
.getContext()
|
||||||
|
@ -103,41 +91,6 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
|
||||||
parser.getNameLoc(), result.operands));
|
parser.getNameLoc(), result.operands));
|
||||||
}
|
}
|
||||||
|
|
||||||
// <operation> ::= `llvm.nvvm.mma.sync %lhs... %rhs... %acc...`
|
|
||||||
// : signature_type
|
|
||||||
static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) {
|
|
||||||
SmallVector<OpAsmParser::OperandType, 12> ops;
|
|
||||||
Type type;
|
|
||||||
llvm::SMLoc typeLoc;
|
|
||||||
if (parser.parseOperandList(ops) ||
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
|
||||||
parser.getCurrentLocation(&typeLoc) || parser.parseType(type)) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto signature = type.dyn_cast<FunctionType>();
|
|
||||||
if (!signature) {
|
|
||||||
return parser.emitError(
|
|
||||||
typeLoc, "expected the type to be the full list of input and output");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (signature.getNumResults() != 1) {
|
|
||||||
return parser.emitError(typeLoc, "expected single result");
|
|
||||||
}
|
|
||||||
|
|
||||||
return failure(parser.addTypeToList(signature.getResult(0), result.types) ||
|
|
||||||
parser.resolveOperands(ops, signature.getInputs(),
|
|
||||||
parser.getNameLoc(), result.operands));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) {
|
|
||||||
p << op.getOperationName() << " " << op.getOperands();
|
|
||||||
p.printOptionalAttrDict(op.getAttrs());
|
|
||||||
p << " : "
|
|
||||||
<< FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()),
|
|
||||||
op.getType(), op.getContext());
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(MmaOp op) {
|
static LogicalResult verify(MmaOp op) {
|
||||||
auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);
|
auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);
|
||||||
|
|
|
@ -30,24 +30,6 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace ROCDL;
|
using namespace ROCDL;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Printing/parsing for ROCDL ops
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
static void printROCDLOp(OpAsmPrinter &p, Operation *op) {
|
|
||||||
p << op->getName() << " " << op->getOperands();
|
|
||||||
if (op->getNumResults() > 0)
|
|
||||||
p << " : " << op->getResultTypes();
|
|
||||||
}
|
|
||||||
|
|
||||||
// <operation> ::= `rocdl.XYZ` : type
|
|
||||||
static ParseResult parseROCDLOp(OpAsmParser &parser, OperationState &result) {
|
|
||||||
Type type;
|
|
||||||
return failure(parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.parseColonType(type) ||
|
|
||||||
parser.addTypeToList(type, result.types));
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ROCDLDialect initialization, type parsing, and registration.
|
// ROCDLDialect initialization, type parsing, and registration.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -446,29 +446,6 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
// CallOp
|
// CallOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
|
|
||||||
FlatSymbolRefAttr calleeAttr;
|
|
||||||
FunctionType calleeType;
|
|
||||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
|
||||||
auto calleeLoc = parser.getNameLoc();
|
|
||||||
if (parser.parseAttribute(calleeAttr, "callee", result.attributes) ||
|
|
||||||
parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.parseColonType(calleeType) ||
|
|
||||||
parser.addTypesToList(calleeType.getResults(), result.types) ||
|
|
||||||
parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
|
|
||||||
result.operands))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, CallOp op) {
|
|
||||||
p << "call " << op.getAttr("callee") << '(' << op.getOperands() << ')';
|
|
||||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
|
|
||||||
p << " : " << op.getCalleeType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(CallOp op) {
|
static LogicalResult verify(CallOp op) {
|
||||||
// Check that the callee attribute was specified.
|
// Check that the callee attribute was specified.
|
||||||
auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
|
auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
|
||||||
|
@ -1184,19 +1161,6 @@ struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
|
||||||
};
|
};
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, DeallocOp op) {
|
|
||||||
p << "dealloc " << op.memref() << " : " << op.memref().getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) {
|
|
||||||
OpAsmParser::OperandType memrefInfo;
|
|
||||||
MemRefType type;
|
|
||||||
|
|
||||||
return failure(parser.parseOperand(memrefInfo) ||
|
|
||||||
parser.parseColonType(type) ||
|
|
||||||
parser.resolveOperand(memrefInfo, type, result.operands));
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(DeallocOp op) {
|
static LogicalResult verify(DeallocOp op) {
|
||||||
if (!op.memref().getType().isa<MemRefType>())
|
if (!op.memref().getType().isa<MemRefType>())
|
||||||
return op.emitOpError("operand must be a memref");
|
return op.emitOpError("operand must be a memref");
|
||||||
|
@ -1844,20 +1808,6 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
|
||||||
// RankOp
|
// RankOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, RankOp op) {
|
|
||||||
p << "rank " << op.getOperand() << " : " << op.getOperand().getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) {
|
|
||||||
OpAsmParser::OperandType operandInfo;
|
|
||||||
Type type;
|
|
||||||
Type indexType = parser.getBuilder().getIndexType();
|
|
||||||
return failure(parser.parseOperand(operandInfo) ||
|
|
||||||
parser.parseColonType(type) ||
|
|
||||||
parser.resolveOperand(operandInfo, type, result.operands) ||
|
|
||||||
parser.addTypeToList(indexType, result.types));
|
|
||||||
}
|
|
||||||
|
|
||||||
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||||
// Constant fold rank when the rank of the tensor is known.
|
// Constant fold rank when the rank of the tensor is known.
|
||||||
auto type = getOperand().getType();
|
auto type = getOperand().getType();
|
||||||
|
|
|
@ -474,38 +474,6 @@ void ExtractSlicesOp::build(Builder *builder, OperationState &result,
|
||||||
result.addAttribute(getStridesAttrName(), stridesAttr);
|
result.addAttribute(getStridesAttrName(), stridesAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ParseResult parseExtractSlicesOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
OpAsmParser::OperandType operandInfo;
|
|
||||||
ArrayAttr sizesAttr;
|
|
||||||
StringRef sizesAttrName = ExtractSlicesOp::getSizesAttrName();
|
|
||||||
ArrayAttr stridesAttr;
|
|
||||||
StringRef stridesAttrName = ExtractSlicesOp::getStridesAttrName();
|
|
||||||
VectorType vectorType;
|
|
||||||
TupleType resultTupleType;
|
|
||||||
return failure(
|
|
||||||
parser.parseOperand(operandInfo) || parser.parseComma() ||
|
|
||||||
parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
|
|
||||||
parser.parseComma() ||
|
|
||||||
parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.parseColonType(vectorType) ||
|
|
||||||
parser.parseKeywordType("into", resultTupleType) ||
|
|
||||||
parser.resolveOperand(operandInfo, vectorType, result.operands) ||
|
|
||||||
parser.addTypeToList(resultTupleType, result.types));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, ExtractSlicesOp op) {
|
|
||||||
p << op.getOperationName() << ' ' << op.vector() << ", ";
|
|
||||||
p << op.sizes() << ", " << op.strides();
|
|
||||||
p.printOptionalAttrDict(
|
|
||||||
op.getAttrs(),
|
|
||||||
/*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(),
|
|
||||||
ExtractSlicesOp::getStridesAttrName()});
|
|
||||||
p << " : " << op.vector().getType();
|
|
||||||
p << " into " << op.getResultTupleType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult
|
static LogicalResult
|
||||||
isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
|
isValidExtractOrInsertSlicesType(Operation *op, VectorType vectorType,
|
||||||
TupleType tupleType, ArrayRef<int64_t> sizes,
|
TupleType tupleType, ArrayRef<int64_t> sizes,
|
||||||
|
@ -572,11 +540,6 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
|
||||||
// BroadcastOp
|
// BroadcastOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, BroadcastOp op) {
|
|
||||||
p << op.getOperationName() << " " << op.source() << " : "
|
|
||||||
<< op.getSourceType() << " to " << op.getVectorType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(BroadcastOp op) {
|
static LogicalResult verify(BroadcastOp op) {
|
||||||
VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
|
VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
|
||||||
VectorType dstVectorType = op.getVectorType();
|
VectorType dstVectorType = op.getVectorType();
|
||||||
|
@ -601,18 +564,6 @@ static LogicalResult verify(BroadcastOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static ParseResult parseBroadcastOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
OpAsmParser::OperandType source;
|
|
||||||
Type sourceType;
|
|
||||||
VectorType vectorType;
|
|
||||||
return failure(parser.parseOperand(source) ||
|
|
||||||
parser.parseColonType(sourceType) ||
|
|
||||||
parser.parseKeywordType("to", vectorType) ||
|
|
||||||
parser.resolveOperand(source, sourceType, result.operands) ||
|
|
||||||
parser.addTypeToList(vectorType, result.types));
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ShuffleOp
|
// ShuffleOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -808,38 +759,6 @@ static LogicalResult verify(InsertOp op) {
|
||||||
// InsertSlicesOp
|
// InsertSlicesOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static ParseResult parseInsertSlicesOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
OpAsmParser::OperandType operandInfo;
|
|
||||||
ArrayAttr sizesAttr;
|
|
||||||
StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName();
|
|
||||||
ArrayAttr stridesAttr;
|
|
||||||
StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName();
|
|
||||||
TupleType tupleType;
|
|
||||||
VectorType resultVectorType;
|
|
||||||
return failure(
|
|
||||||
parser.parseOperand(operandInfo) || parser.parseComma() ||
|
|
||||||
parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
|
|
||||||
parser.parseComma() ||
|
|
||||||
parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.parseColonType(tupleType) ||
|
|
||||||
parser.parseKeywordType("into", resultVectorType) ||
|
|
||||||
parser.resolveOperand(operandInfo, tupleType, result.operands) ||
|
|
||||||
parser.addTypeToList(resultVectorType, result.types));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, InsertSlicesOp op) {
|
|
||||||
p << op.getOperationName() << ' ' << op.vectors() << ", ";
|
|
||||||
p << op.sizes() << ", " << op.strides();
|
|
||||||
p.printOptionalAttrDict(
|
|
||||||
op.getAttrs(),
|
|
||||||
/*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(),
|
|
||||||
InsertSlicesOp::getStridesAttrName()});
|
|
||||||
p << " : " << op.vectors().getType();
|
|
||||||
p << " into " << op.getResultVectorType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(InsertSlicesOp op) {
|
static LogicalResult verify(InsertSlicesOp op) {
|
||||||
SmallVector<int64_t, 4> sizes;
|
SmallVector<int64_t, 4> sizes;
|
||||||
op.getSizes(sizes);
|
op.getSizes(sizes);
|
||||||
|
@ -1231,27 +1150,6 @@ void StridedSliceOp::build(Builder *builder, OperationState &result,
|
||||||
result.addAttribute(getStridesAttrName(), stridesAttr);
|
result.addAttribute(getStridesAttrName(), stridesAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, StridedSliceOp op) {
|
|
||||||
p << op.getOperationName() << " " << op.vector();
|
|
||||||
p.printOptionalAttrDict(op.getAttrs());
|
|
||||||
p << " : " << op.vector().getType() << " to " << op.getResult().getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static ParseResult parseStridedSliceOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
llvm::SMLoc attributeLoc, typeLoc;
|
|
||||||
OpAsmParser::OperandType vector;
|
|
||||||
VectorType vectorType, resultVectorType;
|
|
||||||
return failure(parser.parseOperand(vector) ||
|
|
||||||
parser.getCurrentLocation(&attributeLoc) ||
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.getCurrentLocation(&typeLoc) ||
|
|
||||||
parser.parseColonType(vectorType) ||
|
|
||||||
parser.parseKeywordType("to", resultVectorType) ||
|
|
||||||
parser.resolveOperand(vector, vectorType, result.operands) ||
|
|
||||||
parser.addTypeToList(resultVectorType, result.types));
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(StridedSliceOp op) {
|
static LogicalResult verify(StridedSliceOp op) {
|
||||||
auto type = op.getVectorType();
|
auto type = op.getVectorType();
|
||||||
auto offsets = op.offsets();
|
auto offsets = op.offsets();
|
||||||
|
@ -1519,35 +1417,6 @@ static LogicalResult verify(TransferReadOp op) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TransferWriteOp
|
// TransferWriteOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
static void print(OpAsmPrinter &p, TransferWriteOp op) {
|
|
||||||
p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
|
|
||||||
<< op.indices() << "]";
|
|
||||||
p.printOptionalAttrDict(op.getAttrs());
|
|
||||||
p << " : " << op.getVectorType() << ", " << op.getMemRefType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
llvm::SMLoc typesLoc;
|
|
||||||
OpAsmParser::OperandType storeValueInfo;
|
|
||||||
OpAsmParser::OperandType memRefInfo;
|
|
||||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
|
||||||
SmallVector<Type, 2> types;
|
|
||||||
if (parser.parseOperand(storeValueInfo) || parser.parseComma() ||
|
|
||||||
parser.parseOperand(memRefInfo) ||
|
|
||||||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
|
|
||||||
return failure();
|
|
||||||
if (types.size() != 2)
|
|
||||||
return parser.emitError(typesLoc, "two types required");
|
|
||||||
auto indexType = parser.getBuilder().getIndexType();
|
|
||||||
Type vectorType = types[0], memRefType = types[1];
|
|
||||||
return failure(
|
|
||||||
parser.resolveOperand(storeValueInfo, vectorType, result.operands) ||
|
|
||||||
parser.resolveOperand(memRefInfo, memRefType, result.operands) ||
|
|
||||||
parser.resolveOperands(indexInfo, indexType, result.operands));
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(TransferWriteOp op) {
|
static LogicalResult verify(TransferWriteOp op) {
|
||||||
// Consistency of elemental types in memref and vector.
|
// Consistency of elemental types in memref and vector.
|
||||||
|
@ -1676,23 +1545,6 @@ OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
|
||||||
// ConstantMaskOp
|
// ConstantMaskOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static ParseResult parseConstantMaskOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
Type resultType;
|
|
||||||
ArrayAttr maskDimSizesAttr;
|
|
||||||
StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName();
|
|
||||||
return failure(
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.parseAttribute(maskDimSizesAttr, attrName, result.attributes) ||
|
|
||||||
parser.parseColonType(resultType) ||
|
|
||||||
parser.addTypeToList(resultType, result.types));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, ConstantMaskOp op) {
|
|
||||||
p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
|
|
||||||
<< op.getResult().getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(ConstantMaskOp &op) {
|
static LogicalResult verify(ConstantMaskOp &op) {
|
||||||
// Verify that array attr size matches the rank of the vector result.
|
// Verify that array attr size matches the rank of the vector result.
|
||||||
auto resultType = op.getResult().getType().cast<VectorType>();
|
auto resultType = op.getResult().getType().cast<VectorType>();
|
||||||
|
@ -1724,23 +1576,6 @@ static LogicalResult verify(ConstantMaskOp &op) {
|
||||||
// CreateMaskOp
|
// CreateMaskOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static ParseResult parseCreateMaskOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
auto indexType = parser.getBuilder().getIndexType();
|
|
||||||
Type resultType;
|
|
||||||
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
|
|
||||||
return failure(
|
|
||||||
parser.parseOperandList(operandInfo) ||
|
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
|
||||||
parser.parseColonType(resultType) ||
|
|
||||||
parser.resolveOperands(operandInfo, indexType, result.operands) ||
|
|
||||||
parser.addTypeToList(resultType, result.types));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, CreateMaskOp op) {
|
|
||||||
p << op.getOperationName() << ' ' << op.operands() << " : " << op.getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult verify(CreateMaskOp op) {
|
static LogicalResult verify(CreateMaskOp op) {
|
||||||
// Verify that an operand was specified for each result vector each dimension.
|
// Verify that an operand was specified for each result vector each dimension.
|
||||||
if (op.getNumOperands() !=
|
if (op.getNumOperands() !=
|
||||||
|
@ -1750,23 +1585,6 @@ static LogicalResult verify(CreateMaskOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// PrintOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
static ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
|
|
||||||
OpAsmParser::OperandType source;
|
|
||||||
Type sourceType;
|
|
||||||
return failure(parser.parseOperand(source) ||
|
|
||||||
parser.parseColonType(sourceType) ||
|
|
||||||
parser.resolveOperand(source, sourceType, result.operands));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, PrintOp op) {
|
|
||||||
p << op.getOperationName() << ' ' << op.source() << " : "
|
|
||||||
<< op.getPrintType();
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
|
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
|
||||||
|
|
|
@ -366,7 +366,7 @@ func @nvvm_invalid_mma_6(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
|
||||||
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
|
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
|
||||||
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
||||||
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
||||||
// expected-error@+1 {{expected the type to be the full list of input and output}}
|
// expected-error@+1 {{invalid kind of type specified}}
|
||||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm<"{ float, float, float, float, float, float, float, float }">
|
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||||
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
|
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||||
}
|
}
|
||||||
|
@ -378,9 +378,9 @@ func @nvvm_invalid_mma_7(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
|
||||||
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
|
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
|
||||||
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
||||||
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
||||||
// expected-error@+1 {{expected single result}}
|
// expected-error@+1 {{op requires one result}}
|
||||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
|
%0:2 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
|
||||||
llvm.return %0 : (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
|
llvm.return %0#0 : !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
Loading…
Reference in New Issue