forked from OSchip/llvm-project
[mlir][NFC] Use declarative format for several operations in LLVM and Linalg dialects
Differential Revision: https://reviews.llvm.org/D73503
This commit is contained in:
parent
82170d5619
commit
528adb2e48
|
@ -269,8 +269,9 @@ def LLVM_AllocaOp :
|
|||
def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
|
||||
Arguments<(ins LLVM_Type:$base, Variadic<LLVM_Type>:$indices)>,
|
||||
LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> {
|
||||
let parser = [{ return parseGEPOp(parser, result); }];
|
||||
let printer = [{ printGEPOp(p, *this); }];
|
||||
let assemblyFormat = [{
|
||||
$base `[` $indices `]` attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>,
|
||||
LLVM_Builder<"$res = builder.CreateLoad($addr);"> {
|
||||
|
@ -541,8 +542,7 @@ def LLVM_AddressOfOp
|
|||
GlobalOp getGlobal();
|
||||
}];
|
||||
|
||||
let printer = "printAddressOfOp(p, *this);";
|
||||
let parser = "return parseAddressOfOp(parser, result);";
|
||||
let assemblyFormat = "$global_name attr-dict `:` type($res)";
|
||||
let verifier = "return ::verify(*this);";
|
||||
}
|
||||
|
||||
|
@ -659,23 +659,20 @@ def LLVM_NullOp
|
|||
: LLVM_OneResultOp<"mlir.null", [NoSideEffect]>,
|
||||
LLVM_Builder<"$res = llvm::ConstantPointerNull::get("
|
||||
" cast<llvm::PointerType>($_resultType));"> {
|
||||
let parser = [{ return parseNullOp(parser, result); }];
|
||||
let printer = [{ printNullOp(p, *this); }];
|
||||
let assemblyFormat = "attr-dict `:` type($res)";
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def LLVM_UndefOp : LLVM_OneResultOp<"mlir.undef", [NoSideEffect]>,
|
||||
LLVM_Builder<"$res = llvm::UndefValue::get($_resultType);"> {
|
||||
let parser = [{ return parseUndefOp(parser, result); }];
|
||||
let printer = [{ printUndefOp(p, *this); }];
|
||||
let assemblyFormat = "attr-dict `:` type($res)";
|
||||
}
|
||||
def LLVM_ConstantOp
|
||||
: LLVM_OneResultOp<"mlir.constant", [NoSideEffect]>,
|
||||
Arguments<(ins AnyAttr:$value)>,
|
||||
LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);">
|
||||
{
|
||||
let parser = [{ return parseConstantOp(parser, result); }];
|
||||
let printer = [{ printConstantOp(p, *this); }];
|
||||
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
// Operations that correspond to LLVM intrinsics. With MLIR operation set being
|
||||
|
|
|
@ -56,6 +56,7 @@ def Linalg_RangeOp :
|
|||
|
||||
// Fully specified by traits.
|
||||
let verifier = ?;
|
||||
let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)";
|
||||
}
|
||||
|
||||
def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
|
||||
|
@ -117,7 +118,9 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
|
|||
static StringRef getReassociationAttrName() { return "reassociation"; }
|
||||
MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$view $reassociation attr-dict `:` type($view) `into` type(results)
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -183,8 +183,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
|
||||
: Op<Linalg_Dialect, mnemonic,
|
||||
!listconcat(props, [StructuredOpTraits, LinalgStructuredInterface])> {
|
||||
let parser = [{ return parseLinalgStructuredOp(parser, result); }];
|
||||
let printer = [{ printLinalgStructuredOp(p, *this); }];
|
||||
}
|
||||
|
||||
class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
|
||||
|
@ -194,6 +192,7 @@ class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
|
|||
return generateLibraryCallName(getOperation());
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -152,51 +152,6 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for LLVM::GEPOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
|
||||
SmallVector<Type, 8> types(op.getOperandTypes());
|
||||
auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
|
||||
|
||||
p << op.getOperationName() << ' ' << op.base() << '['
|
||||
<< op.getOperands().drop_front() << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << funcTy;
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]`
|
||||
// attribute-dict? `:` type
|
||||
static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType base;
|
||||
SmallVector<OpAsmParser::OperandType, 8> indices;
|
||||
Type type;
|
||||
llvm::SMLoc trailingTypeLoc;
|
||||
if (parser.parseOperand(base) ||
|
||||
parser.parseOperandList(indices, OpAsmParser::Delimiter::Square) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
||||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
|
||||
return failure();
|
||||
|
||||
// Deconstruct the trailing function type to extract the types of the base
|
||||
// pointer and result (same type) and the types of the indices.
|
||||
auto funcType = type.dyn_cast<FunctionType>();
|
||||
if (!funcType || funcType.getNumResults() != 1 ||
|
||||
funcType.getNumInputs() == 0)
|
||||
return parser.emitError(trailingTypeLoc,
|
||||
"expected trailing function type with at least "
|
||||
"one argument and one result");
|
||||
|
||||
if (parser.resolveOperand(base, funcType.getInput(0), result.operands) ||
|
||||
parser.resolveOperands(indices, funcType.getInputs().drop_front(),
|
||||
parser.getNameLoc(), result.operands))
|
||||
return failure();
|
||||
|
||||
result.addTypes(funcType.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for LLVM::LoadOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -957,29 +912,7 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for LLVM::UndefOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printUndefOp(OpAsmPrinter &p, UndefOp &op) {
|
||||
p << op.getOperationName();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.res().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.mlir.undef` attribute-dict? : type
|
||||
static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &result) {
|
||||
Type type;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
result.addTypes(type);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printer, parser and verifier for LLVM::AddressOfOp.
|
||||
// Verifier for LLVM::AddressOfOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
GlobalOp AddressOfOp::getGlobal() {
|
||||
|
@ -991,27 +924,6 @@ GlobalOp AddressOfOp::getGlobal() {
|
|||
mlir::SymbolTable::lookupSymbolIn(module, global_name()));
|
||||
}
|
||||
|
||||
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
|
||||
p << op.getOperationName() << " ";
|
||||
p.printSymbolName(op.global_name());
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"global_name"});
|
||||
p << " : " << op.getResult().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseAddressOfOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
Attribute symRef;
|
||||
Type type;
|
||||
if (parser.parseAttribute(symRef, "global_name", result.attributes) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type) || parser.addTypeToList(type, result.types))
|
||||
return failure();
|
||||
|
||||
if (!symRef.isa<SymbolRefAttr>())
|
||||
return parser.emitError(parser.getNameLoc(), "expected symbol reference");
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(AddressOfOp op) {
|
||||
auto global = op.getGlobal();
|
||||
if (!global)
|
||||
|
@ -1026,32 +938,6 @@ static LogicalResult verify(AddressOfOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for LLVM::ConstantOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) {
|
||||
p << op.getOperationName() << '(' << op.value() << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"value"});
|
||||
p << " : " << op.res().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.mlir.constant` `(` attribute `)` attribute-list? : type
|
||||
static ParseResult parseConstantOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
Attribute valueAttr;
|
||||
Type type;
|
||||
|
||||
if (parser.parseLParen() ||
|
||||
parser.parseAttribute(valueAttr, "value", result.attributes) ||
|
||||
parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
result.addTypes(type);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Builder, printer and verifier for LLVM::GlobalOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1516,23 +1402,9 @@ static LogicalResult verify(LLVMFuncOp op) {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing, parsing and verification for LLVM::NullOp.
|
||||
// Verification for LLVM::NullOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printNullOp(OpAsmPrinter &p, LLVM::NullOp op) {
|
||||
p << NullOp::getOperationName();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getType();
|
||||
}
|
||||
|
||||
// <operation> = `llvm.mlir.null` : type
|
||||
static ParseResult parseNullOp(OpAsmParser &parser, OperationState &result) {
|
||||
Type type;
|
||||
return failure(parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.addTypeToList(type, result.types));
|
||||
}
|
||||
|
||||
// Only LLVM pointer types are supported.
|
||||
static LogicalResult verify(LLVM::NullOp op) {
|
||||
auto llvmType = op.getType().dyn_cast<LLVM::LLVMType>();
|
||||
|
|
|
@ -379,30 +379,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
|
|||
static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
|
||||
static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, RangeOp op) {
|
||||
p << op.getOperationName() << " " << op.min() << ":" << op.max() << ":"
|
||||
<< op.step();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getResult().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
|
||||
RangeType type;
|
||||
auto indexTy = parser.getBuilder().getIndexType();
|
||||
return failure(parser.parseOperand(rangeInfo[0]) || parser.parseColon() ||
|
||||
parser.parseOperand(rangeInfo[1]) || parser.parseColon() ||
|
||||
parser.parseOperand(rangeInfo[2]) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.resolveOperands(rangeInfo, indexTy, result.operands) ||
|
||||
parser.addTypeToList(type, result.types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -583,28 +559,6 @@ void mlir::linalg::ReshapeOp::build(
|
|||
b->getAffineMapArrayAttr(maps));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, ReshapeOp op) {
|
||||
p << op.getOperationName() << " " << op.view() << " " << op.reassociation();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
{ReshapeOp::getReassociationAttrName()});
|
||||
p << " : " << op.getViewType() << " into " << op.getResult().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType view;
|
||||
ArrayAttr reassociation;
|
||||
MemRefType type, resultType;
|
||||
return failure(parser.parseOperand(view) ||
|
||||
parser.parseAttribute(reassociation,
|
||||
ReshapeOp::getReassociationAttrName(),
|
||||
result.attributes) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.parseKeywordType("into", resultType) ||
|
||||
parser.resolveOperand(view, type, result.operands) ||
|
||||
parser.addTypeToList(resultType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(ReshapeOp op) {
|
||||
MemRefType expandedType = op.getViewType();
|
||||
MemRefType collapsedType = op.getResult().getType().cast<MemRefType>();
|
||||
|
@ -838,43 +792,6 @@ static LogicalResult verify(YieldOp op) {
|
|||
}
|
||||
|
||||
/////// Operations corresponding to library calls defined with Tablegen ////////
|
||||
// For such operations correspond to library calls (i.e. defined in
|
||||
// LinalgStructuredOps.td), we define an overloaded `print` function and a
|
||||
// parse`className` function.
|
||||
|
||||
// A LinalgStructuredOp prints as:
|
||||
//
|
||||
// ```mlir
|
||||
// concrete_op_name (ssa-inputs, ssa-outputs) : view-types
|
||||
// ```
|
||||
//
|
||||
// for example:
|
||||
//
|
||||
// ```
|
||||
// linalg.matmul(%0, %1, %2) :
|
||||
// memref<?x?xf32, stride_specification>,
|
||||
// memref<?x?xf32, stride_specification>,
|
||||
// memref<?x?xf32, stride_specification>
|
||||
// ```
|
||||
//
|
||||
// Where %0, %1 and %2 are ssa-values of type MemRefType with strides.
|
||||
static void printLinalgStructuredOp(OpAsmPrinter &p, Operation *op) {
|
||||
assert(op->getAbstractOperation() && "unregistered operation");
|
||||
p << op->getName().getStringRef() << "(" << op->getOperands() << ")";
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
p << " : " << op->getOperandTypes();
|
||||
}
|
||||
|
||||
static ParseResult parseLinalgStructuredOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 3> ops;
|
||||
SmallVector<Type, 3> types;
|
||||
return failure(
|
||||
parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonTypeList(types) ||
|
||||
parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands));
|
||||
}
|
||||
|
||||
static LogicalResult verify(FillOp op) {
|
||||
auto viewType = op.getOutputShapedType(0);
|
||||
|
|
|
@ -118,7 +118,7 @@ func @bar() {
|
|||
func @foo() {
|
||||
// The attribute parser will consume the first colon-type, so we put two of
|
||||
// them to trigger the attribute type mismatch error.
|
||||
// expected-error @+1 {{expected symbol reference}}
|
||||
// expected-error @+1 {{invalid kind of attribute specified}}
|
||||
llvm.mlir.addressof "foo" : i64 : !llvm<"void ()*">
|
||||
}
|
||||
|
||||
|
|
|
@ -63,28 +63,28 @@ func @alloca_nonpositive_alignment(%size : !llvm.i64) {
|
|||
// -----
|
||||
|
||||
func @gep_missing_input_result_type(%pos : !llvm.i64, %base : !llvm<"float*">) {
|
||||
// expected-error@+1 {{expected trailing function type with at least one argument and one result}}
|
||||
// expected-error@+1 {{2 operands present, but expected 0}}
|
||||
llvm.getelementptr %base[%pos] : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @gep_missing_input_type(%pos : !llvm.i64, %base : !llvm<"float*">) {
|
||||
// expected-error@+1 {{expected trailing function type with at least one argument and one result}}
|
||||
// expected-error@+1 {{2 operands present, but expected 0}}
|
||||
llvm.getelementptr %base[%pos] : () -> (!llvm<"float*">)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @gep_missing_result_type(%pos : !llvm.i64, %base : !llvm<"float*">) {
|
||||
// expected-error@+1 {{expected trailing function type with at least one argument and one result}}
|
||||
// expected-error@+1 {{op requires one result}}
|
||||
llvm.getelementptr %base[%pos] : (!llvm<"float *">, !llvm.i64) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @gep_non_function_type(%pos : !llvm.i64, %base : !llvm<"float*">) {
|
||||
// expected-error@+1 {{expected trailing function type with at least one argument and one result}}
|
||||
// expected-error@+1 {{invalid kind of type specified}}
|
||||
llvm.getelementptr %base[%pos] : !llvm<"float*">
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
|
|||
return
|
||||
}
|
||||
// CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||
// CHECK-NEXT: linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
|
||||
// CHECK-NEXT: linalg.range %{{.*}} : %{{.*}} : %{{.*}} : !linalg.range
|
||||
|
||||
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
|
||||
%c0 = constant 0 : index
|
||||
|
|
Loading…
Reference in New Issue