Fix minor issues in pair_snap_kokkos

This commit is contained in:
Stan Moore 2018-01-11 09:39:53 -07:00
parent d7d087ae67
commit db1ed32a51
2 changed files with 9 additions and 30 deletions

View File

@ -48,7 +48,6 @@ public:
void coeff(int, char**);
void init_style();
void compute(int, int);
double memory_usage();
template<int NEIGHFLAG, int EVFLAG>
KOKKOS_INLINE_FUNCTION
@ -83,7 +82,7 @@ protected:
// How much parallelism to use within an interaction
int vector_length;
int eflag,vflag,nlocal;
int eflag,vflag;
void allocate();
//void read_files(char *, char *);

View File

@ -133,7 +133,7 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
eflag = eflag_in;
vflag = vflag_in;
if (neighflag == FULL) no_virial_fdotr_compute = 1; // FIX ME??
if (neighflag == FULL) no_virial_fdotr_compute = 1;
if (eflag || vflag) ev_setup(eflag,vflag,0);
else evflag = vflag_fdotr = 0;
@ -160,13 +160,12 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
x = atomKK->k_x.view<DeviceType>();
f = atomKK->k_f.view<DeviceType>();
type = atomKK->k_type.view<DeviceType>();
nlocal = atom->nlocal;
NeighListKokkos<DeviceType>* k_list = static_cast<NeighListKokkos<DeviceType>*>(list);
d_numneigh = k_list->d_numneigh;
d_neighbors = k_list->d_neighbors;
d_ilist = k_list->d_ilist;
//int inum = list->inum;
int inum = list->inum;
/*
for (int i = 0; i < nlocal; i++) {
@ -175,7 +174,7 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
if (max_neighs<num_neighs) max_neighs = num_neighs;
}*/
int max_neighs = 0;
Kokkos::parallel_reduce("PairSNAPKokkos::find_max_neighs",nlocal, FindMaxNumNeighs<DeviceType>(k_list), Kokkos::Experimental::Max<int>(max_neighs));
Kokkos::parallel_reduce("PairSNAPKokkos::find_max_neighs",inum, FindMaxNumNeighs<DeviceType>(k_list), Kokkos::Experimental::Max<int>(max_neighs));
snaKK.nmax = max_neighs;
@ -197,13 +196,13 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
if (eflag) {
if (neighflag == HALF) {
typename Kokkos::TeamPolicy<DeviceType, TagPairSNAP<HALF,1> > policy(nlocal,team_size,vector_length);
typename Kokkos::TeamPolicy<DeviceType, TagPairSNAP<HALF,1> > policy(inum,team_size,vector_length);
Kokkos::parallel_reduce(policy
.set_scratch_size(1,Kokkos::PerThread(thread_scratch_size))
.set_scratch_size(1,Kokkos::PerTeam(team_scratch_size))
,*this,ev);
} else if (neighflag == HALFTHREAD) {
typename Kokkos::TeamPolicy<DeviceType, TagPairSNAP<HALFTHREAD,1> > policy(nlocal,team_size,vector_length);
typename Kokkos::TeamPolicy<DeviceType, TagPairSNAP<HALFTHREAD,1> > policy(inum,team_size,vector_length);
Kokkos::parallel_reduce(policy
.set_scratch_size(1,Kokkos::PerThread(thread_scratch_size))
.set_scratch_size(1,Kokkos::PerTeam(team_scratch_size))
@ -211,13 +210,13 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
}
} else {
if (neighflag == HALF) {
typename Kokkos::TeamPolicy<DeviceType, TagPairSNAP<HALF,0> > policy(nlocal,team_size,vector_length);
typename Kokkos::TeamPolicy<DeviceType, TagPairSNAP<HALF,0> > policy(inum,team_size,vector_length);
Kokkos::parallel_for(policy
.set_scratch_size(1,Kokkos::PerThread(thread_scratch_size))
.set_scratch_size(1,Kokkos::PerTeam(team_scratch_size))
,*this);
} else if (neighflag == HALFTHREAD) {
typename Kokkos::TeamPolicy<DeviceType, TagPairSNAP<HALFTHREAD,0> > policy(nlocal,team_size,vector_length);
typename Kokkos::TeamPolicy<DeviceType, TagPairSNAP<HALFTHREAD,0> > policy(inum,team_size,vector_length);
Kokkos::parallel_for(policy
.set_scratch_size(1,Kokkos::PerThread(thread_scratch_size))
.set_scratch_size(1,Kokkos::PerTeam(team_scratch_size))
@ -615,23 +614,4 @@ void PairSNAPKokkos<DeviceType>::v_tally_xyz(EV_FLOAT &ev, const int &i, const i
v_vatom(j,4) += 0.5*v4;
v_vatom(j,5) += 0.5*v5;
}
}
/* ----------------------------------------------------------------------
memory usage
------------------------------------------------------------------------- */
template<class DeviceType>
double PairSNAPKokkos<DeviceType>::memory_usage()
{
double bytes = Pair::memory_usage();
int n = atom->ntypes+1;
bytes += n*n*sizeof(int);
bytes += n*n*sizeof(double);
bytes += 3*nmax*sizeof(double);
bytes += nmax*sizeof(int);
bytes += (2*ncoeffall)*sizeof(double);
bytes += (ncoeff*3)*sizeof(double);
//bytes += snaKK.memory_usage(); // FIXME
return bytes;
}
}