diff --git a/src/KOKKOS/comm_kokkos.cpp b/src/KOKKOS/comm_kokkos.cpp index 3ec22c42fa..5dc1e5fa4a 100644 --- a/src/KOKKOS/comm_kokkos.cpp +++ b/src/KOKKOS/comm_kokkos.cpp @@ -71,6 +71,10 @@ CommKokkos::CommKokkos(LAMMPS *lmp) : CommBrick(lmp) maxsendlist[i] = BUFMIN; } memory->create_kokkos(k_sendlist,sendlist,maxswap,BUFMIN,"comm:sendlist"); + + max_buf_pair = 0; + k_buf_send_pair = DAT::tdual_xfloat_1d("comm:k_buf_send_pair",1); + k_buf_recv_pair = DAT::tdual_xfloat_1d("comm:k_recv_send_pair",1); } /* ---------------------------------------------------------------------- */ @@ -300,9 +304,13 @@ void CommKokkos::forward_comm_pair_device(Pair *pair) int nsize = pair->comm_forward; for (iswap = 0; iswap < nswap; iswap++) { + int n = MAX(max_buf_pair,nsize*sendnum[iswap]); + n = MAX(n,nsize*recvnum[iswap]); + if (n > max_buf_pair) + grow_buf_pair(n); + } - DAT::tdual_xfloat_1d k_buf_send_pair = DAT::tdual_xfloat_1d("comm:k_buf_send_pair",nsize*sendnum[iswap]); - DAT::tdual_xfloat_1d k_buf_recv_pair = DAT::tdual_xfloat_1d("comm:k_recv_send_pair",nsize*recvnum[iswap]); + for (iswap = 0; iswap < nswap; iswap++) { // pack buffer @@ -327,6 +335,12 @@ void CommKokkos::forward_comm_pair_device(Pair *pair) } } +void CommKokkos::grow_buf_pair(int n) { + max_buf_pair = n * BUFFACTOR; + k_buf_send_pair.resize(max_buf_pair); + k_buf_recv_pair.resize(max_buf_pair); +} + void CommKokkos::reverse_comm_pair(Pair *pair) { k_sendlist.sync(); diff --git a/src/KOKKOS/comm_kokkos.h b/src/KOKKOS/comm_kokkos.h index 587b030595..71d5e59595 100644 --- a/src/KOKKOS/comm_kokkos.h +++ b/src/KOKKOS/comm_kokkos.h @@ -59,6 +59,11 @@ class CommKokkos : public CommBrick { //double *buf_send; // send buffer for all comm //double *buf_recv; // recv buffer for all comm + int max_buf_pair; + DAT::tdual_xfloat_1d k_buf_send_pair; + DAT::tdual_xfloat_1d k_buf_recv_pair; + void grow_buf_pair(int); + void grow_send(int, int); void grow_recv(int); void grow_send_kokkos(int, int, ExecutionSpace space = Host);