[mlir][OpenMP] Restrict types for omp.parallel args

This patch restricts the value of `if` clause expression to an I1 value.
It also restricts the value of `num_threads` clause expression to an I32
value.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D124142
This commit is contained in:
Shraiysh Vaishay 2022-05-02 10:54:28 +05:30
parent c8603db071
commit a60fda59dc
4 changed files with 74 additions and 20 deletions

View File

@ -254,8 +254,10 @@ genOMP(Fortran::lower::AbstractConverter &converter,
if (const auto &ifClause =
std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
ifClauseOperand = fir::getBase(
mlir::Value ifVal = fir::getBase(
converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
ifClauseOperand = firOpBuilder.createConvert(
currentLocation, firOpBuilder.getI1Type(), ifVal);
} else if (const auto &numThreadsClause =
std::get_if<Fortran::parser::OmpClause::NumThreads>(
&clause.u)) {

View File

@ -15,8 +15,13 @@ end subroutine parallel_simple
!===============================================================================
!FIRDialect-LABEL: func @_QPparallel_if
subroutine parallel_if(alpha)
subroutine parallel_if(alpha, beta, gamma)
integer, intent(in) :: alpha
logical, intent(in) :: beta
logical(1) :: logical1
logical(2) :: logical2
logical(4) :: logical4
logical(8) :: logical8
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
!$omp parallel if(alpha .le. 0)
@ -46,6 +51,41 @@ subroutine parallel_if(alpha)
!OMPDialect: omp.terminator
!$omp end parallel
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
!$omp parallel if(beta)
!FIRDialect: fir.call
call f1()
!OMPDialect: omp.terminator
!$omp end parallel
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
!$omp parallel if(logical1)
!FIRDialect: fir.call
call f1()
!OMPDialect: omp.terminator
!$omp end parallel
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
!$omp parallel if(logical2)
!FIRDialect: fir.call
call f1()
!OMPDialect: omp.terminator
!$omp end parallel
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
!$omp parallel if(logical4)
!FIRDialect: fir.call
call f1()
!OMPDialect: omp.terminator
!$omp end parallel
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
!$omp parallel if(logical8)
!FIRDialect: fir.call
call f1()
!OMPDialect: omp.terminator
!$omp end parallel
end subroutine parallel_if
!===============================================================================

View File

@ -99,8 +99,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
of the parallel region.
}];
let arguments = (ins Optional<AnyType>:$if_expr_var,
Optional<AnyType>:$num_threads_var,
let arguments = (ins Optional<I1>:$if_expr_var,
Optional<IntLikeType>:$num_threads_var,
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,

View File

@ -51,15 +51,15 @@ func.func @omp_terminator() -> () {
omp.terminator
}
func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32) -> () {
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%if_cond, %num_threads, %data_var, %data_var) ({
// test without if condition
// CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
// CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel"(%num_threads, %data_var, %data_var) ({
omp.terminator
}) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (si32, memref<i32>, memref<i32>) -> ()
}) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (i32, memref<i32>, memref<i32>) -> ()
// CHECK: omp.barrier
omp.barrier
@ -71,13 +71,13 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : s
}) {operand_segment_sizes = dense<[1,0,1,1,0]> : vector<5xi32>} : (i1, memref<i32>, memref<i32>) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
}) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, si32) -> ()
}) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, i32) -> ()
omp.terminator
}) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
}) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, i32, memref<i32>, memref<i32>) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
@ -88,14 +88,26 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : s
return
}
func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32, %allocator : si32) -> () {
func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32, %allocator : si32) -> () {
// CHECK: omp.parallel
omp.parallel {
omp.terminator
}
// CHECK: omp.parallel num_threads(%{{.*}} : si32)
omp.parallel num_threads(%num_threads : si32) {
// CHECK: omp.parallel num_threads(%{{.*}} : i32)
omp.parallel num_threads(%num_threads : i32) {
omp.terminator
}
%n_index = arith.constant 2 : index
// CHECK: omp.parallel num_threads(%{{.*}} : index)
omp.parallel num_threads(%n_index : index) {
omp.terminator
}
%n_i64 = arith.constant 4 : i64
// CHECK: omp.parallel num_threads(%{{.*}} : i64)
omp.parallel num_threads(%n_i64 : i64) {
omp.terminator
}
@ -113,8 +125,8 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
// CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close)
omp.parallel num_threads(%num_threads : si32) if(%if_cond: i1) proc_bind(close) {
// CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) private(%{{.*}} : memref<i32>) proc_bind(close)
omp.parallel num_threads(%num_threads : i32) if(%if_cond: i1) proc_bind(close) {
omp.terminator
}
@ -347,14 +359,14 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
}
// CHECK-LABEL: omp_target
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32) -> () {
// Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
// CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
}) {nowait, operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>} : ( i1, si32, si32 ) -> ()
}) {nowait, operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>} : ( i1, si32, i32 ) -> ()
// CHECK: omp.barrier
omp.barrier
@ -363,14 +375,14 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> ()
}
// CHECK-LABEL: omp_target_pretty
func.func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
func.func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : i32) -> () {
// CHECK: omp.target if({{.*}}) device({{.*}})
omp.target if(%if_cond) device(%device : si32) {
omp.terminator
}
// CHECK: omp.target if({{.*}}) device({{.*}}) nowait
omp.target if(%if_cond) device(%device : si32) thread_limit(%num_threads : si32) nowait {
omp.target if(%if_cond) device(%device : si32) thread_limit(%num_threads : i32) nowait {
omp.terminator
}