[flang] Update fir.dispatch operation

Update the `fir.dispatch` operation to prepare
the lowering part. `nopass` and `pass_arg_pos` attributes
are added in the arguments list so accessors are generated
by MLIR tablegen. A verifier is added as well as some tests.

This patch is part of the implementation of the poltymorphic
entities.
https://github.com/llvm/llvm-project/blob/main/flang/docs/PolymorphicEntities.md

Reviewed By: jeanPerier, PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D135358
This commit is contained in:
Valentin Clement 2022-10-06 18:10:33 +02:00
parent 2d4fd0b6d5
commit 5a0722e046
No known key found for this signature in database
GPG Key ID: 086D54783C928776
5 changed files with 74 additions and 23 deletions

View File

@ -2327,22 +2327,30 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
let description = [{
Perform a dynamic dispatch on the method name via the dispatch table
associated with the first argument. The attribute 'pass_arg_pos' can be
used to select a dispatch argument other than the first one.
associated with the first operand. The attribute `pass_arg_pos` can be
used to select a dispatch operand other than the first one. The absence of
`pass_arg_pos` attribute means nopass.
```mlir
%r = fir.dispatch methodA(%o) : (!fir.box<none>) -> i32
// fir.dispatch with no attribute.
%r = fir.dispatch "methodA"(%o) : (!fir.class<T>) -> i32
// fir.dispatch with the `pass_arg_pos` attribute.
%r = fir.dispatch "methodA"(%o, %o) : (!fir.class<T>, !fir.class<T>) -> i32 {pass_arg_pos = 0 : i32}
```
}];
let arguments = (ins
StrAttr:$method,
fir_BoxType:$object,
Variadic<AnyType>:$args
fir_ClassType:$object,
Variadic<AnyType>:$args,
OptionalAttr<I32Attr>:$pass_arg_pos
);
let results = (outs Variadic<AnyType>);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
@ -2350,14 +2358,10 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}
// operand[0] is the object (of box type)
// operand[0] is the object (of class type)
operand_iterator arg_operand_begin() { return operand_begin() + 1; }
operand_iterator arg_operand_end() { return operand_end(); }
static constexpr llvm::StringRef getPassArgAttrName() {
return "pass_arg_pos";
}
static constexpr llvm::StringRef getMethodAttrNameStr() { return "method"; }
unsigned passArgPos();
}];
}

View File

@ -1038,6 +1038,20 @@ mlir::LogicalResult fir::CoordinateOp::verify() {
// DispatchOp
//===----------------------------------------------------------------------===//
mlir::LogicalResult fir::DispatchOp::verify() {
// Check that pass_arg_pos is in range of actual operands. pass_arg_pos is
// unsigned so check for less than zero is not needed.
if (getPassArgPos() && *getPassArgPos() > (getArgOperands().size() - 1))
return emitOpError(
"pass_arg_pos must be smaller than the number of operands");
// Operand pointed by pass_arg_pos must have polymorphic type.
if (getPassArgPos() &&
!fir::isPolymorphicType(getArgOperands()[*getPassArgPos()].getType()))
return emitOpError("pass_arg_pos must be a polymorphic operand");
return mlir::success();
}
mlir::FunctionType fir::DispatchOp::getFunctionType() {
return mlir::FunctionType::get(getContext(), getOperandTypes(),
getResultTypes());
@ -1060,11 +1074,11 @@ mlir::ParseResult fir::DispatchOp::parse(mlir::OpAsmParser &parser,
parser.getBuilder().getStringAttr(calleeName));
}
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(calleeType) ||
parser.addTypesToList(calleeType.getResults(), result.types) ||
parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
result.operands))
result.operands) ||
parser.parseOptionalAttrDict(result.attributes))
return mlir::failure();
return mlir::success();
}
@ -1079,6 +1093,9 @@ void fir::DispatchOp::print(mlir::OpAsmPrinter &p) {
p << ") : ";
p.printFunctionalType(getOperation()->getOperandTypes(),
getOperation()->getResultTypes());
p.printOptionalAttrDict(getOperation()->getAttrs(),
{mlir::SymbolTable::getSymbolAttrName(),
fir::DispatchOp::getMethodAttrNameStr()});
}
//===----------------------------------------------------------------------===//

View File

@ -3,8 +3,8 @@
// Test `fir.dispatch` conversion to llvm.
// Not implemented yet.
func.func @dispatch(%arg0: !fir.box<!fir.type<derived3{f:f32}>>) {
// CHECK: not yet implemented: fir.dispatch codegen
%0 = fir.dispatch "method"(%arg0) : (!fir.box<!fir.type<derived3{f:f32}>>) -> i32
func.func @dispatch(%arg0: !fir.class<!fir.type<derived3{f:f32}>>) {
// CHECK: not yet implemented: fir.class type conversion
%0 = fir.dispatch "method"(%arg0) : (!fir.class<!fir.type<derived3{f:f32}>>) -> i32
return
}

View File

@ -114,14 +114,14 @@ func.func @instructions() {
%25 = fir.insert_value %22, %cf1, ["f", !fir.type<derived{f:f32}>] : (!fir.type<derived{f:f32}>, f32) -> !fir.type<derived{f:f32}>
%26 = fir.len_param_index f, !fir.type<derived3{f:f32}>
// CHECK: [[VAL_31:%.*]] = fir.call @box3() : () -> !fir.box<!fir.type<derived3{f:f32}>>
// CHECK: [[VAL_32:%.*]] = fir.dispatch "method"([[VAL_31]]) : (!fir.box<!fir.type<derived3{f:f32}>>) -> i32
// CHECK: [[VAL_31:%.*]] = fir.call @box3() : () -> !fir.class<!fir.type<derived3{f:f32}>>
// CHECK: [[VAL_32:%.*]] = fir.dispatch "method"([[VAL_31]]) : (!fir.class<!fir.type<derived3{f:f32}>>) -> i32
// CHECK: [[VAL_33:%.*]] = fir.convert [[VAL_32]] : (i32) -> i64
// CHECK: [[VAL_34:%.*]] = fir.gentypedesc !fir.type<x>
// CHECK: fir.call @user_tdesc([[VAL_34]]) : (!fir.tdesc<!fir.type<x>>) -> ()
// CHECK: [[VAL_35:%.*]] = fir.no_reassoc [[VAL_33]] : i64
%27 = fir.call @box3() : () -> !fir.box<!fir.type<derived3{f:f32}>>
%28 = fir.dispatch "method"(%27) : (!fir.box<!fir.type<derived3{f:f32}>>) -> i32
%27 = fir.call @box3() : () -> !fir.class<!fir.type<derived3{f:f32}>>
%28 = fir.dispatch "method"(%27) : (!fir.class<!fir.type<derived3{f:f32}>>) -> i32
%29 = fir.convert %28 : (i32) -> i64
%30 = fir.gentypedesc !fir.type<x>
fir.call @user_tdesc(%30) : (!fir.tdesc<!fir.type<x>>) -> ()
@ -309,12 +309,12 @@ func.func @bar_select_rank(%arg : i32, %arg2 : i32) -> i32 {
// CHECK: ^bb5:
// CHECK: [[VAL_99:%.*]] = arith.constant 0 : i32
// CHECK: [[VAL_100:%.*]] = fir.call @get_method_box() : () -> !fir.box<!fir.type<derived3{f:f32}>>
// CHECK: fir.dispatch "method"([[VAL_100]]) : (!fir.box<!fir.type<derived3{f:f32}>>) -> ()
// CHECK: [[VAL_100:%.*]] = fir.call @get_method_box() : () -> !fir.class<!fir.type<derived3{f:f32}>>
// CHECK: fir.dispatch "method"([[VAL_100]]) : (!fir.class<!fir.type<derived3{f:f32}>>) -> ()
^bb5 :
%zero = arith.constant 0 : i32
%7 = fir.call @get_method_box() : () -> !fir.box<!fir.type<derived3{f:f32}>>
fir.dispatch method(%7) : (!fir.box<!fir.type<derived3{f:f32}>>) -> ()
%7 = fir.call @get_method_box() : () -> !fir.class<!fir.type<derived3{f:f32}>>
fir.dispatch method(%7) : (!fir.class<!fir.type<derived3{f:f32}>>) -> ()
// CHECK: return [[VAL_99]] : i32
// CHECK: }
@ -805,3 +805,17 @@ func.func @array_amend_ops(%a : !fir.ref<!fir.array<?x?xf32>>) {
// CHECK: %{{.*}} = fir.array_amend %{{.*}}, %{{.*}} : (!fir.array<?x?xf32>, !fir.ref<f32>) -> !fir.array<?x?xf32>
return
}
func.func private @dispatch(%arg0: !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, %arg1: i32) -> () {
// CHECK-LABEL: func.func private @dispatch(
// CHECK-SAME: %[[CLASS:.*]]: !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, %[[INTARG:.*]]: i32)
fir.dispatch "proc1"(%arg0, %arg0) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {pass_arg_pos = 0 : i32}
// CHECK: fir.dispatch "proc1"(%[[CLASS]], %[[CLASS]]) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {pass_arg_pos = 0 : i32}
fir.dispatch "proc2"(%arg0) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {nopass}
// CHECK: fir.dispatch "proc2"(%[[CLASS]]) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {nopass}
fir.dispatch "proc3"(%arg0, %arg1, %arg0) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, i32, !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {pass_arg_pos = 1 : i32}
// CHECK: fir.dispatch "proc3"(%[[CLASS]], %[[INTARG]], %[[CLASS]]) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, i32, !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {pass_arg_pos = 1 : i32}
return
}

View File

@ -756,3 +756,19 @@ func.func @foo(%arg0: !fir.ref<!fir.array<30x!fir.type<t{c:!fir.array<20xi32>}>>
return
}
func.func private @ifoo(!fir.ref<f32>) -> i32
// -----
func.func private @dispatch(%arg0: !fir.class<!fir.type<derived{a:i32,b:i32}>>) -> () {
// expected-error@+1 {{'fir.dispatch' op pass_arg_pos must be smaller than the number of operands}}
fir.dispatch "proc1"(%arg0, %arg0) : (!fir.class<!fir.type<derived{a:i32,b:i32}>>, !fir.class<!fir.type<derived{a:i32,b:i32}>>) -> () {pass_arg_pos = 1 : i32}
return
}
// -----
func.func private @dispatch(%arg0: !fir.class<!fir.type<derived{a:i32,b:i32}>>, %arg1: i32) -> () {
// expected-error@+1 {{'fir.dispatch' op pass_arg_pos must be a polymorphic operand}}
fir.dispatch "proc1"(%arg0, %arg0, %arg1) : (!fir.class<!fir.type<derived{a:i32,b:i32}>>, !fir.class<!fir.type<derived{a:i32,b:i32}>>, i32) -> () {pass_arg_pos = 1 : i32}
return
}