[mlir][NFC] Use declarative format for several operations in LLVM and Linalg dialects

Differential Revision: https://reviews.llvm.org/D73503
This commit is contained in:
River Riddle 2020-01-30 11:32:04 -08:00
parent 82170d5619
commit 528adb2e48
8 changed files with 20 additions and 232 deletions

View File

@ -269,8 +269,9 @@ def LLVM_AllocaOp :
def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
Arguments<(ins LLVM_Type:$base, Variadic<LLVM_Type>:$indices)>,
LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> {
let parser = [{ return parseGEPOp(parser, result); }];
let printer = [{ printGEPOp(p, *this); }];
let assemblyFormat = [{
$base `[` $indices `]` attr-dict `:` functional-type(operands, results)
}];
}
def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>,
LLVM_Builder<"$res = builder.CreateLoad($addr);"> {
@ -541,8 +542,7 @@ def LLVM_AddressOfOp
GlobalOp getGlobal();
}];
let printer = "printAddressOfOp(p, *this);";
let parser = "return parseAddressOfOp(parser, result);";
let assemblyFormat = "$global_name attr-dict `:` type($res)";
let verifier = "return ::verify(*this);";
}
@ -659,23 +659,20 @@ def LLVM_NullOp
: LLVM_OneResultOp<"mlir.null", [NoSideEffect]>,
LLVM_Builder<"$res = llvm::ConstantPointerNull::get("
" cast<llvm::PointerType>($_resultType));"> {
let parser = [{ return parseNullOp(parser, result); }];
let printer = [{ printNullOp(p, *this); }];
let assemblyFormat = "attr-dict `:` type($res)";
let verifier = [{ return ::verify(*this); }];
}
def LLVM_UndefOp : LLVM_OneResultOp<"mlir.undef", [NoSideEffect]>,
LLVM_Builder<"$res = llvm::UndefValue::get($_resultType);"> {
let parser = [{ return parseUndefOp(parser, result); }];
let printer = [{ printUndefOp(p, *this); }];
let assemblyFormat = "attr-dict `:` type($res)";
}
def LLVM_ConstantOp
: LLVM_OneResultOp<"mlir.constant", [NoSideEffect]>,
Arguments<(ins AnyAttr:$value)>,
LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);">
{
let parser = [{ return parseConstantOp(parser, result); }];
let printer = [{ printConstantOp(p, *this); }];
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
}
// Operations that correspond to LLVM intrinsics. With MLIR operation set being

View File

@ -56,6 +56,7 @@ def Linalg_RangeOp :
// Fully specified by traits.
let verifier = ?;
let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)";
}
def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
@ -117,7 +118,9 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
static StringRef getReassociationAttrName() { return "reassociation"; }
MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
}];
let assemblyFormat = [{
$view $reassociation attr-dict `:` type($view) `into` type(results)
}];
let hasFolder = 1;
}

View File

@ -183,8 +183,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
: Op<Linalg_Dialect, mnemonic,
!listconcat(props, [StructuredOpTraits, LinalgStructuredInterface])> {
let parser = [{ return parseLinalgStructuredOp(parser, result); }];
let printer = [{ printLinalgStructuredOp(p, *this); }];
}
class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
@ -194,6 +192,7 @@ class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
return generateLibraryCallName(getOperation());
}
}];
let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -152,51 +152,6 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::GEPOp.
//===----------------------------------------------------------------------===//
static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
SmallVector<Type, 8> types(op.getOperandTypes());
auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
p << op.getOperationName() << ' ' << op.base() << '['
<< op.getOperands().drop_front() << ']';
p.printOptionalAttrDict(op.getAttrs());
p << " : " << funcTy;
}
// <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]`
// attribute-dict? `:` type
static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType base;
SmallVector<OpAsmParser::OperandType, 8> indices;
Type type;
llvm::SMLoc trailingTypeLoc;
if (parser.parseOperand(base) ||
parser.parseOperandList(indices, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
return failure();
// Deconstruct the trailing function type to extract the types of the base
// pointer and result (same type) and the types of the indices.
auto funcType = type.dyn_cast<FunctionType>();
if (!funcType || funcType.getNumResults() != 1 ||
funcType.getNumInputs() == 0)
return parser.emitError(trailingTypeLoc,
"expected trailing function type with at least "
"one argument and one result");
if (parser.resolveOperand(base, funcType.getInput(0), result.operands) ||
parser.resolveOperands(indices, funcType.getInputs().drop_front(),
parser.getNameLoc(), result.operands))
return failure();
result.addTypes(funcType.getResults());
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
@ -957,29 +912,7 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::UndefOp.
//===----------------------------------------------------------------------===//
static void printUndefOp(OpAsmPrinter &p, UndefOp &op) {
p << op.getOperationName();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.res().getType();
}
// <operation> ::= `llvm.mlir.undef` attribute-dict? : type
static ParseResult parseUndefOp(OpAsmParser &parser, OperationState &result) {
Type type;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type))
return failure();
result.addTypes(type);
return success();
}
//===----------------------------------------------------------------------===//
// Printer, parser and verifier for LLVM::AddressOfOp.
// Verifier for LLVM::AddressOfOp.
//===----------------------------------------------------------------------===//
GlobalOp AddressOfOp::getGlobal() {
@ -991,27 +924,6 @@ GlobalOp AddressOfOp::getGlobal() {
mlir::SymbolTable::lookupSymbolIn(module, global_name()));
}
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
p << op.getOperationName() << " ";
p.printSymbolName(op.global_name());
p.printOptionalAttrDict(op.getAttrs(), {"global_name"});
p << " : " << op.getResult().getType();
}
static ParseResult parseAddressOfOp(OpAsmParser &parser,
OperationState &result) {
Attribute symRef;
Type type;
if (parser.parseAttribute(symRef, "global_name", result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) || parser.addTypeToList(type, result.types))
return failure();
if (!symRef.isa<SymbolRefAttr>())
return parser.emitError(parser.getNameLoc(), "expected symbol reference");
return success();
}
static LogicalResult verify(AddressOfOp op) {
auto global = op.getGlobal();
if (!global)
@ -1026,32 +938,6 @@ static LogicalResult verify(AddressOfOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ConstantOp.
//===----------------------------------------------------------------------===//
static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) {
p << op.getOperationName() << '(' << op.value() << ')';
p.printOptionalAttrDict(op.getAttrs(), {"value"});
p << " : " << op.res().getType();
}
// <operation> ::= `llvm.mlir.constant` `(` attribute `)` attribute-list? : type
static ParseResult parseConstantOp(OpAsmParser &parser,
OperationState &result) {
Attribute valueAttr;
Type type;
if (parser.parseLParen() ||
parser.parseAttribute(valueAttr, "value", result.attributes) ||
parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type))
return failure();
result.addTypes(type);
return success();
}
//===----------------------------------------------------------------------===//
// Builder, printer and verifier for LLVM::GlobalOp.
//===----------------------------------------------------------------------===//
@ -1516,23 +1402,9 @@ static LogicalResult verify(LLVMFuncOp op) {
}
//===----------------------------------------------------------------------===//
// Printing, parsing and verification for LLVM::NullOp.
// Verification for LLVM::NullOp.
//===----------------------------------------------------------------------===//
static void printNullOp(OpAsmPrinter &p, LLVM::NullOp op) {
p << NullOp::getOperationName();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getType();
}
// <operation> = `llvm.mlir.null` : type
static ParseResult parseNullOp(OpAsmParser &parser, OperationState &result) {
Type type;
return failure(parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.addTypeToList(type, result.types));
}
// Only LLVM pointer types are supported.
static LogicalResult verify(LLVM::NullOp op) {
auto llvmType = op.getType().dyn_cast<LLVM::LLVMType>();

View File

@ -379,30 +379,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
//===----------------------------------------------------------------------===//
// RangeOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, RangeOp op) {
p << op.getOperationName() << " " << op.min() << ":" << op.max() << ":"
<< op.step();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getResult().getType();
}
static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type;
auto indexTy = parser.getBuilder().getIndexType();
return failure(parser.parseOperand(rangeInfo[0]) || parser.parseColon() ||
parser.parseOperand(rangeInfo[1]) || parser.parseColon() ||
parser.parseOperand(rangeInfo[2]) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperands(rangeInfo, indexTy, result.operands) ||
parser.addTypeToList(type, result.types));
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
@ -583,28 +559,6 @@ void mlir::linalg::ReshapeOp::build(
b->getAffineMapArrayAttr(maps));
}
static void print(OpAsmPrinter &p, ReshapeOp op) {
p << op.getOperationName() << " " << op.view() << " " << op.reassociation();
p.printOptionalAttrDict(op.getAttrs(),
{ReshapeOp::getReassociationAttrName()});
p << " : " << op.getViewType() << " into " << op.getResult().getType();
}
static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType view;
ArrayAttr reassociation;
MemRefType type, resultType;
return failure(parser.parseOperand(view) ||
parser.parseAttribute(reassociation,
ReshapeOp::getReassociationAttrName(),
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.parseKeywordType("into", resultType) ||
parser.resolveOperand(view, type, result.operands) ||
parser.addTypeToList(resultType, result.types));
}
static LogicalResult verify(ReshapeOp op) {
MemRefType expandedType = op.getViewType();
MemRefType collapsedType = op.getResult().getType().cast<MemRefType>();
@ -838,43 +792,6 @@ static LogicalResult verify(YieldOp op) {
}
/////// Operations corresponding to library calls defined with Tablegen ////////
// For such operations correspond to library calls (i.e. defined in
// LinalgStructuredOps.td), we define an overloaded `print` function and a
// parse`className` function.
// A LinalgStructuredOp prints as:
//
// ```mlir
// concrete_op_name (ssa-inputs, ssa-outputs) : view-types
// ```
//
// for example:
//
// ```
// linalg.matmul(%0, %1, %2) :
// memref<?x?xf32, stride_specification>,
// memref<?x?xf32, stride_specification>,
// memref<?x?xf32, stride_specification>
// ```
//
// Where %0, %1 and %2 are ssa-values of type MemRefType with strides.
static void printLinalgStructuredOp(OpAsmPrinter &p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation");
p << op->getName().getStringRef() << "(" << op->getOperands() << ")";
p.printOptionalAttrDict(op->getAttrs());
p << " : " << op->getOperandTypes();
}
static ParseResult parseLinalgStructuredOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types;
return failure(
parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(types) ||
parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands));
}
static LogicalResult verify(FillOp op) {
auto viewType = op.getOutputShapedType(0);

View File

@ -118,7 +118,7 @@ func @bar() {
func @foo() {
// The attribute parser will consume the first colon-type, so we put two of
// them to trigger the attribute type mismatch error.
// expected-error @+1 {{expected symbol reference}}
// expected-error @+1 {{invalid kind of attribute specified}}
llvm.mlir.addressof "foo" : i64 : !llvm<"void ()*">
}

View File

@ -63,28 +63,28 @@ func @alloca_nonpositive_alignment(%size : !llvm.i64) {
// -----
func @gep_missing_input_result_type(%pos : !llvm.i64, %base : !llvm<"float*">) {
// expected-error@+1 {{expected trailing function type with at least one argument and one result}}
// expected-error@+1 {{2 operands present, but expected 0}}
llvm.getelementptr %base[%pos] : () -> ()
}
// -----
func @gep_missing_input_type(%pos : !llvm.i64, %base : !llvm<"float*">) {
// expected-error@+1 {{expected trailing function type with at least one argument and one result}}
// expected-error@+1 {{2 operands present, but expected 0}}
llvm.getelementptr %base[%pos] : () -> (!llvm<"float*">)
}
// -----
func @gep_missing_result_type(%pos : !llvm.i64, %base : !llvm<"float*">) {
// expected-error@+1 {{expected trailing function type with at least one argument and one result}}
// expected-error@+1 {{op requires one result}}
llvm.getelementptr %base[%pos] : (!llvm<"float *">, !llvm.i64) -> ()
}
// -----
func @gep_non_function_type(%pos : !llvm.i64, %base : !llvm<"float*">) {
// expected-error@+1 {{expected trailing function type with at least one argument and one result}}
// expected-error@+1 {{invalid kind of type specified}}
llvm.getelementptr %base[%pos] : !llvm<"float*">
}

View File

@ -29,7 +29,7 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
return
}
// CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
// CHECK-NEXT: linalg.range %{{.*}} : %{{.*}} : %{{.*}} : !linalg.range
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%c0 = constant 0 : index