master sponge performance

This commit is contained in:
mamba_ni 2021-04-01 15:05:39 +08:00
parent 18d79d35b6
commit 28811fa958
40 changed files with 4381 additions and 3689 deletions

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/crd_to_uint_crd_impl.cuh"
// Converts floating-point coordinates into integer coordinates, one thread per atom.
// Each component of crd[atom] is multiplied by the matching component of
// scale_factor[0], stored into the uint_x / uint_y / uint_z fields of
// uint_crd[atom], and the stored value is then shifted left by one bit.
//   atom_numbers  number of entries of crd / uint_crd to process
//   scale_factor  per-axis multiplier; only element 0 is read
//   crd           input coordinates
//   uint_crd      output coordinates
// Expects a 1D launch with at least atom_numbers threads in total.
__global__ void Crd_To_Uint_Crd(const int atom_numbers, const VECTOR *scale_factor, const VECTOR *crd,
                                UNSIGNED_INT_VECTOR *uint_crd) {
  const int atom_i = blockDim.x * blockIdx.x + threadIdx.x;
  // Threads past the end of the atom array have nothing to do.
  if (atom_i >= atom_numbers) {
    return;
  }

  const VECTOR &scale = scale_factor[0];
  const VECTOR &position = crd[atom_i];
  UNSIGNED_INT_VECTOR &quantized = uint_crd[atom_i];

  // Scale each axis and store it into the integer fields.
  quantized.uint_x = position.x * scale.x;
  quantized.uint_y = position.y * scale.y;
  quantized.uint_z = position.z * scale.z;

  // Shift every stored component left by one bit.
  quantized.uint_x <<= 1;
  quantized.uint_y <<= 1;
  quantized.uint_z <<= 1;
}
// Host-side launcher for the Crd_To_Uint_Crd kernel.
//   atom_numbers            number of atoms to convert; nothing is launched when <= 0
//   crd_to_uint_crd_cof_f   scale factors, reinterpreted as a single VECTOR
//   crd_f                   input coordinates, reinterpreted as an array of VECTOR
//   uint_crd_f              output buffer, reinterpreted as an array of UNSIGNED_INT_VECTOR
//   stream                  CUDA stream the kernel is enqueued on
// The launch is asynchronous; this function does not synchronize the stream.
void CrdToUintCrd(const int atom_numbers, const float *crd_to_uint_crd_cof_f, const float *crd_f,
                  unsigned int *uint_crd_f, cudaStream_t stream) {
  // A launch with zero blocks is an invalid configuration, so skip empty work up front.
  if (atom_numbers <= 0) {
    return;
  }
  constexpr int kThreadsPerBlock = 128;
  // Integer ceil-div is exact for every int; going through float loses precision
  // above 2^24 and could leave the last atoms without a thread.
  const int block_count = (atom_numbers - 1) / kThreadsPerBlock + 1;

  // The kernel only reads the inputs, so the const qualifiers are kept.
  const VECTOR *crd = reinterpret_cast<const VECTOR *>(crd_f);
  const VECTOR *crd_to_uint_crd_cof = reinterpret_cast<const VECTOR *>(crd_to_uint_crd_cof_f);
  UNSIGNED_INT_VECTOR *uint_crd = reinterpret_cast<UNSIGNED_INT_VECTOR *>(uint_crd_f);

  Crd_To_Uint_Crd<<<block_count, kThreadsPerBlock, 0, stream>>>(atom_numbers, crd_to_uint_crd_cof, crd, uint_crd);
}
void CrdToUintCrd(const int atom_numbers, const float *crd_to_uint_crd_cof_f, const float *crd_f,
unsigned int *uint_crd_f, cudaStream_t stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_CRD_TO_UINT_CRD_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_CRD_TO_UINT_CRD_IMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
// Enqueues the coordinate-to-unsigned-integer conversion kernel on `stream`.
//   atom_numbers            number of atoms to convert
//   crd_to_uint_crd_cof_f   per-axis scale factors (the implementation reads them as one VECTOR)
//   crd_f                   input coordinates (read as an array of VECTOR)
//   uint_crd_f              output coordinates (written as an array of UNSIGNED_INT_VECTOR)
//   stream                  CUDA stream used for the kernel launch
// The pointers are handed straight to a kernel, so they must be device-accessible.
void CrdToUintCrd(const int atom_numbers, const float *crd_to_uint_crd_cof_f, const float *crd_f,
unsigned int *uint_crd_f, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_CRD_TO_UINT_CRD_IMPL_H_

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_impl.cuh"
#include <vector>
__global__ void Copy_List(const int element_numbers, const int *origin_list, int *list) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
@ -387,7 +387,7 @@ __global__ void Mul_half(float *src, float *dst) {
}
}
void Neighbor_List_Update(int grid_numbers, int atom_numbers, int refresh_count, int refresh_interval,
void Neighbor_List_Update(int grid_numbers, int atom_numbers, int *d_refresh_count, int refresh_interval,
int not_first_time, float skin, int Nxy, float cutoff_square, float cutoff_with_skin_square,
int *grid_N, float *box_length, int *atom_numbers_in_grid_bucket, float *grid_length_inverse,
int *atom_in_grid_serial, GRID_BUCKET *bucket, float *crd, float *old_crd,
@ -397,15 +397,22 @@ void Neighbor_List_Update(int grid_numbers, int atom_numbers, int refresh_count,
int *is_need_refresh_neighbor_list, cudaStream_t stream) {
if (not_first_time) {
if (refresh_interval > 0) {
std::vector<int> refresh_count_list(1);
cudaMemcpyAsync(refresh_count_list.data(), d_refresh_count, sizeof(int), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
int refresh_count = refresh_count_list[0];
if (refresh_count % refresh_interval == 0) {
Mul_half<<<1, 3, 0, stream>>>(crd_to_uint_crd_cof, half_crd_to_uint_crd_cof);
Refresh_Neighbor_List_No_Check(
grid_numbers, atom_numbers, skin, Nxy, cutoff_square, grid_N, box_length, atom_numbers_in_grid_bucket,
grid_length_inverse, atom_in_grid_serial, bucket, reinterpret_cast<VECTOR *>(crd),
reinterpret_cast<VECTOR *>(old_crd), crd_to_uint_crd_cof, reinterpret_cast<UNSIGNED_INT_VECTOR *>(uint_crd),
uint_dr_to_dr_cof, gpointer, d_nl, excluded_list_start, excluded_list, excluded_numbers, stream);
Refresh_Neighbor_List_No_Check(grid_numbers, atom_numbers, skin, Nxy, cutoff_square, grid_N, box_length,
atom_numbers_in_grid_bucket, grid_length_inverse, atom_in_grid_serial, bucket,
reinterpret_cast<VECTOR *>(crd), reinterpret_cast<VECTOR *>(old_crd),
half_crd_to_uint_crd_cof, reinterpret_cast<UNSIGNED_INT_VECTOR *>(uint_crd),
uint_dr_to_dr_cof, gpointer, d_nl, excluded_list_start, excluded_list,
excluded_numbers, stream);
}
refresh_count += 1;
cudaMemcpyAsync(d_refresh_count, &refresh_count, sizeof(int), cudaMemcpyHostToDevice, stream);
} else {
Is_need_refresh_neighbor_list_cuda<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(
atom_numbers, reinterpret_cast<VECTOR *>(crd), reinterpret_cast<VECTOR *>(old_crd), half_skin_square,

View File

@ -48,7 +48,7 @@ void Construct_Neighbor_List(int grid_numbers, int max_neighbor_numbers, int *nl
void CopyNeighborListAtomNumber(int atom_numbers, NEIGHBOR_LIST *nl, int *nl_atom_numbers, cudaStream_t stream);
void Neighbor_List_Update(int grid_numbers, int atom_numbers, int refresh_count, int refresh_interval,
void Neighbor_List_Update(int grid_numbers, int atom_numbers, int* d_refresh_count, int refresh_interval,
int not_first_time, float skin, int Nxy, float cutoff_square, float cutoff_with_skin_square,
int *grid_N, float *box_length, int *atom_numbers_in_grid_bucket, float *grid_length_inverse,
int *atom_in_grid_serial, GRID_BUCKET *bucket, float *crd, float *old_crd,

View File

@ -0,0 +1,67 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/nvtit/md_iteration_leap_frog_liujian_gpu_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
// Advances one atom per thread through a single integration step.
// For atom i, in this order:
//   1. acc    = inverse_mass * frc
//   2. vel   += dt * acc
//   3. output = crd + half_dt * vel
//   4. vel    = exp_gamma * vel + sqrt_mass_inverse * random_frc
//   5. output += half_dt * vel
// acc, vel and output are written; crd, frc, random_frc and the mass arrays are only read.
// Expects a 1D launch with at least atom_numbers threads in total.
__global__ void MD_Iteration_Leap_Frog_With_LiuJian_kernel(const int atom_numbers, const float half_dt, const float dt,
                                                           const float exp_gamma, float *inverse_mass,
                                                           float *sqrt_mass_inverse, VECTOR *vel, VECTOR *crd,
                                                           VECTOR *frc, VECTOR *acc, VECTOR *random_frc,
                                                           VECTOR *output) {
  const int i = blockDim.x * blockIdx.x + threadIdx.x;
  // Surplus threads in the last block do nothing.
  if (i >= atom_numbers) {
    return;
  }

  const float &inv_mass = inverse_mass[i];
  const float &noise_scale = sqrt_mass_inverse[i];
  const VECTOR &force = frc[i];
  const VECTOR &position = crd[i];
  const VECTOR &noise = random_frc[i];
  VECTOR &accel = acc[i];
  VECTOR &velocity = vel[i];
  VECTOR &next = output[i];

  // Acceleration from the current force.
  accel.x = inv_mass * force.x;
  accel.y = inv_mass * force.y;
  accel.z = inv_mass * force.z;

  // Full-step velocity update.
  velocity.x += dt * accel.x;
  velocity.y += dt * accel.y;
  velocity.z += dt * accel.z;

  // First half-step of the position, using the updated velocity.
  next.x = position.x + half_dt * velocity.x;
  next.y = position.y + half_dt * velocity.y;
  next.z = position.z + half_dt * velocity.z;

  // Damp the velocity by exp_gamma and add the mass-scaled random term.
  velocity.x = exp_gamma * velocity.x + noise_scale * noise.x;
  velocity.y = exp_gamma * velocity.y + noise_scale * noise.y;
  velocity.z = exp_gamma * velocity.z + noise_scale * noise.z;

  // Second half-step of the position, using the modified velocity.
  next.x += half_dt * velocity.x;
  next.y += half_dt * velocity.y;
  next.z += half_dt * velocity.z;
}
// Host-side driver for one integration step.
// First enqueues Rand_Normal to fill rand_frc (viewed as float4_numbers float4 values)
// from rand_state, then enqueues MD_Iteration_Leap_Frog_With_LiuJian_kernel, which
// consumes rand_frc together with the per-atom arrays. Both launches go to `stream`,
// so the second one runs after the first; the function itself does not synchronize.
//   atom_numbers        number of atoms; the integration kernel is skipped when <= 0
//   half_dt, dt         time-step values forwarded to the kernel
//   exp_gamma           velocity damping factor forwarded to the kernel
//   float4_numbers      number of float4 random values to generate; skipped when <= 0
//   inverse_mass        per-atom scalars forwarded to the kernel
//   sqrt_mass_inverse   per-atom scalars forwarded to the kernel
//   vel, crd, frc, acc  per-atom arrays, reinterpreted as arrays of VECTOR
//   rand_state          generator states handed to Rand_Normal
//   rand_frc            random buffer: written as float4, read back as VECTOR
//   output              per-atom result array, reinterpreted as an array of VECTOR
void MD_Iteration_Leap_Frog_With_LiuJian(const int atom_numbers, const float half_dt, const float dt,
                                         const float exp_gamma, int float4_numbers, float *inverse_mass,
                                         float *sqrt_mass_inverse, float *vel, float *crd, float *frc, float *acc,
                                         curandStatePhilox4_32_10_t *rand_state, float *rand_frc, float *output,
                                         cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 32;

  // Block counts use integer ceil-div: it is exact for every int, whereas rounding
  // through float can drop the tail block for counts above 2^24. Empty work is
  // skipped because a zero-block launch is an invalid configuration.
  if (float4_numbers > 0) {
    const int rand_blocks = (float4_numbers - 1) / kThreadsPerBlock + 1;
    Rand_Normal<<<rand_blocks, kThreadsPerBlock, 0, stream>>>(float4_numbers, rand_state,
                                                              reinterpret_cast<float4 *>(rand_frc));
  }

  if (atom_numbers <= 0) {
    return;
  }
  const int atom_blocks = (atom_numbers - 1) / kThreadsPerBlock + 1;

  VECTOR *d_vel = reinterpret_cast<VECTOR *>(vel);
  VECTOR *d_crd = reinterpret_cast<VECTOR *>(crd);
  VECTOR *d_frc = reinterpret_cast<VECTOR *>(frc);
  VECTOR *d_acc = reinterpret_cast<VECTOR *>(acc);
  VECTOR *d_rand_frc = reinterpret_cast<VECTOR *>(rand_frc);
  VECTOR *d_out = reinterpret_cast<VECTOR *>(output);
  MD_Iteration_Leap_Frog_With_LiuJian_kernel<<<atom_blocks, kThreadsPerBlock, 0, stream>>>(
    atom_numbers, half_dt, dt, exp_gamma, inverse_mass, sqrt_mass_inverse, d_vel, d_crd, d_frc, d_acc, d_rand_frc,
    d_out);
}

View File

@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_MD_ITERATION_LEAP_FROG_LIUJIAN_GPU_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_MD_ITERATION_LEAP_FROG_LIUJIAN_GPU_IMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
// Enqueues one integration step on `stream`: random values are generated into
// rand_frc from rand_state, then the per-atom update kernel runs over atom_numbers atoms.
//   atom_numbers              number of atoms to update
//   half_dt, dt, exp_gamma    scalar coefficients forwarded to the update kernel
//   float4_numbers            number of float4 random values to generate into rand_frc
//   inverse_mass              per-atom scalar array
//   sqrt_mass_inverse         per-atom scalar array
//   vel, crd, frc, acc        per-atom arrays (the implementation views them as VECTOR arrays)
//   rand_state                random generator states
//   rand_frc                  buffer for the generated random values
//   output                    per-atom result array (viewed as a VECTOR array)
//   stream                    CUDA stream used for both kernel launches
void MD_Iteration_Leap_Frog_With_LiuJian(const int atom_numbers, const float half_dt, const float dt,
const float exp_gamma, int float4_numbers, float *inverse_mass,
float *sqrt_mass_inverse, float *vel, float *crd, float *frc, float *acc,
curandStatePhilox4_32_10_t *rand_state, float *rand_frc, float *output,
cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_MD_ITERATION_LEAP_FROG_LIUJIAN_GPU_IMPL_H_

View File

@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/nvtit/md_iteration_setup_random_state_gpu_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
// Host-side launcher for Setup_Rand_Normal_Kernel.
//   float4_numbers  element count forwarded to the kernel; nothing is launched when <= 0
//   rand_state      generator states forwarded to the kernel
//   seed            seed value forwarded to the kernel
//   stream          CUDA stream the kernel is enqueued on
// The launch is asynchronous; this function does not synchronize the stream.
void MD_Iteration_Setup_Random_State(int float4_numbers, curandStatePhilox4_32_10_t *rand_state, int seed,
                                     cudaStream_t stream) {
  // A launch with zero blocks is an invalid configuration, so skip empty work up front.
  if (float4_numbers <= 0) {
    return;
  }
  constexpr int kThreadsPerBlock = 32;
  // Integer ceil-div is exact for every int; rounding through float is not above 2^24.
  const int block_count = (float4_numbers - 1) / kThreadsPerBlock + 1;
  Setup_Rand_Normal_Kernel<<<block_count, kThreadsPerBlock, 0, stream>>>(float4_numbers, rand_state, seed);
}
void MD_Iteration_Setup_Random_State(int float4_numbers, curandStatePhilox4_32_10_t *rand_state, int seed,
cudaStream_t stream);

View File

@ -0,0 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_MD_ITERATION_SETUP_RANDOM_STATE_GPU_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_MD_ITERATION_SETUP_RANDOM_STATE_GPU_IMPL_H_
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
// Enqueues the random-state setup kernel on `stream`.
//   float4_numbers  element count forwarded to the setup kernel
//   rand_state      generator states to set up
//   seed            seed value forwarded to the setup kernel
//   stream          CUDA stream used for the kernel launch
void MD_Iteration_Setup_Random_State(int float4_numbers, curandStatePhilox4_32_10_t *rand_state, int seed,
cudaStream_t stream);
#endif

View File

@ -93,12 +93,13 @@ __global__ void PME_Excluded_Energy_Correction(const int atom_numbers, const UNS
}
}
void PMEEnergy(int fftx, int ffty, int fftz, int atom_numbers, float beta, float *box_length_f, float *PME_BC,
int *pme_uxyz, float *pme_frxyz, float *PME_Q, float *pme_fq, int *PME_atom_near, int *pme_kxyz,
const int *uint_crd_f, const float *charge, int *nl_atom_numbers, int *nl_atom_serial, int *nl,
const float *scaler_f, const int *excluded_list_start, const int *excluded_list,
const int *excluded_atom_numbers, float *d_reciprocal_ene, float *d_self_ene, float *d_direct_ene,
float *d_correction_ene, cudaStream_t stream) {
void PMEEnergy(int fftx, int ffty, int fftz, int atom_numbers, float beta, float *PME_BC, int *pme_uxyz,
float *pme_frxyz, float *PME_Q, float *pme_fq, int *PME_atom_near, int *pme_kxyz, const int *uint_crd_f,
const float *charge, int *nl_atom_numbers, int *nl_atom_serial, int *nl, const float *scaler_f,
const int *excluded_list_start, const int *excluded_list, const int *excluded_atom_numbers,
float *d_reciprocal_ene, float *d_self_ene, float *d_direct_ene, float *d_correction_ene,
dim3 thread_PME, int PME_Nin, int PME_Nfft, int PME_Nall, const cufftHandle &PME_plan_r2c,
const cufftHandle &PME_plan_c2r, cudaStream_t stream) {
UNSIGNED_INT_VECTOR *uint_crd =
const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));
VECTOR *scaler = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(scaler_f));
@ -106,97 +107,11 @@ void PMEEnergy(int fftx, int ffty, int fftz, int atom_numbers, float beta, float
NEIGHBOR_LIST *nl_a = reinterpret_cast<NEIGHBOR_LIST *>(nl);
construct_neighbor_list_kernel<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(
atom_numbers, max_neighbor_numbers, nl_atom_numbers, nl_atom_serial, nl_a);
std::vector<float> h_box_length(3);
cudaMemcpyAsync(h_box_length.data(), box_length_f, sizeof(float) * h_box_length.size(), cudaMemcpyDeviceToHost,
stream);
cudaStreamSynchronize(stream);
VECTOR *box_length = reinterpret_cast<VECTOR *>(h_box_length.data());
UNSIGNED_INT_VECTOR *PME_uxyz = reinterpret_cast<UNSIGNED_INT_VECTOR *>(pme_uxyz);
UNSIGNED_INT_VECTOR *PME_kxyz = reinterpret_cast<UNSIGNED_INT_VECTOR *>(pme_kxyz);
VECTOR *PME_frxyz = reinterpret_cast<VECTOR *>(pme_frxyz);
cufftComplex *PME_FQ = reinterpret_cast<cufftComplex *>(pme_fq);
cufftHandle PME_plan_r2c;
cufftHandle PME_plan_c2r;
cufftPlan3d(&PME_plan_r2c, fftx, ffty, fftz, CUFFT_R2C);
cufftPlan3d(&PME_plan_c2r, fftx, ffty, fftz, CUFFT_C2R);
cufftSetStream(PME_plan_r2c, stream);
cufftSetStream(PME_plan_c2r, stream);
thread_PME.x = 8;
thread_PME.y = 8;
int PME_Nin = ffty * fftz;
int PME_Nfft = fftx * ffty * (fftz / 2 + 1);
int PME_Nall = fftx * ffty * fftz;
float volume = box_length[0].x * box_length[0].y * box_length[0].z;
UNSIGNED_INT_VECTOR *PME_kxyz_cpu;
Malloc_Safely(reinterpret_cast<void **>(&PME_kxyz_cpu), sizeof(UNSIGNED_INT_VECTOR) * 64);
int kx, ky, kz, kxrp, kyrp, kzrp, index;
for (kx = 0; kx < 4; kx++) {
for (ky = 0; ky < 4; ky++) {
for (kz = 0; kz < 4; kz++) {
index = kx * 16 + ky * 4 + kz;
PME_kxyz_cpu[index].uint_x = kx;
PME_kxyz_cpu[index].uint_y = ky;
PME_kxyz_cpu[index].uint_z = kz;
}
}
}
cudaMemcpyAsync(PME_kxyz, PME_kxyz_cpu, sizeof(UNSIGNED_INT_VECTOR) * 64, cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);
free(PME_kxyz_cpu);
// initial start
float *B1, *B2, *B3, *PME_BC0;
B1 = reinterpret_cast<float *>(malloc(sizeof(float) * fftx));
B2 = reinterpret_cast<float *>(malloc(sizeof(float) * ffty));
B3 = reinterpret_cast<float *>(malloc(sizeof(float) * fftz));
PME_BC0 = reinterpret_cast<float *>(malloc(sizeof(float) * PME_Nfft));
for (kx = 0; kx < fftx; kx++) {
B1[kx] = getb(kx, fftx, 4);
}
for (ky = 0; ky < ffty; ky++) {
B2[ky] = getb(ky, ffty, 4);
}
for (kz = 0; kz < fftz; kz++) {
B3[kz] = getb(kz, fftz, 4);
}
float mprefactor = PI * PI / -beta / beta;
float msq;
for (kx = 0; kx < fftx; kx++) {
kxrp = kx;
if (kx > fftx / 2) kxrp = fftx - kx;
for (ky = 0; ky < ffty; ky++) {
kyrp = ky;
if (ky > ffty / 2) kyrp = ffty - ky;
for (kz = 0; kz <= fftz / 2; kz++) {
kzrp = kz;
msq = kxrp * kxrp / box_length[0].x / box_length[0].x + kyrp * kyrp / box_length[0].y / box_length[0].y +
kzrp * kzrp / box_length[0].z / box_length[0].z;
index = kx * ffty * (fftz / 2 + 1) + ky * (fftz / 2 + 1) + kz;
if ((kx + ky + kz) == 0) {
PME_BC0[index] = 0;
} else {
PME_BC0[index] = 1.0 / PI / msq * exp(mprefactor * msq) / volume;
}
PME_BC0[index] *= B1[kx] * B2[ky] * B3[kz];
}
}
}
cudaMemcpyAsync(PME_BC, PME_BC0, sizeof(float) * PME_Nfft, cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);
free(B1);
free(B2);
free(B3);
free(PME_BC0);
Reset_List<<<3 * atom_numbers / 32 + 1, 32, 0, stream>>>(3 * atom_numbers, reinterpret_cast<int *>(PME_uxyz),
1 << 30);
@ -226,9 +141,3 @@ void PMEEnergy(int fftx, int ffty, int fftz, int atom_numbers, float beta, float
d_correction_ene);
return;
}
void PMEEnergy(int fftx, int ffty, int fftz, int atom_numbers, float beta, float *box_length_f, float *PME_BC,
int *pme_uxyz, float *pme_frxyz, float *PME_Q, float *pme_fq, int *PME_atom_near, int *pme_kxyz,
const int *uint_crd_f, const float *charge, int *nl_atom_numbers, int *nl_atom_serial, int *nl,
const float *scaler_f, const int *excluded_list_start, const int *excluded_list,
const int *excluded_atom_numbers, float *d_reciprocal_ene, float *d_self_ene, float *d_direct_ene,
float *d_correction_ene, cudaStream_t stream);

View File

@ -16,15 +16,15 @@
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_PME_ENERGY_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_PME_ENERGY_IMPL_H_
#include <curand_kernel.h>
#include <vector>
#include <cufft.h>
#include "runtime/device/gpu/cuda_common.h"
void PMEEnergy(int fftx, int ffty, int fftz, int atom_numbers, float beta, float *box_length_f, float *PME_BC,
int *pme_uxyz, float *pme_frxyz, float *PME_Q, float *pme_fq, int *PME_atom_near, int *pme_kxyz,
const int *uint_crd_f, const float *charge, int *nl_atom_numbers, int *nl_atom_serial, int *nl,
const float *scaler_f, const int *excluded_list_start, const int *excluded_list,
const int *excluded_atom_numbers, float *d_reciprocal_ene, float *d_self_ene, float *d_direct_ene,
float *d_correction_ene, cudaStream_t stream);
void PMEEnergy(int fftx, int ffty, int fftz, int atom_numbers, float beta, float *PME_BC, int *pme_uxyz,
float *pme_frxyz, float *PME_Q, float *pme_fq, int *PME_atom_near, int *pme_kxyz, const int *uint_crd_f,
const float *charge, int *nl_atom_numbers, int *nl_atom_serial, int *nl, const float *scaler_f,
const int *excluded_list_start, const int *excluded_list, const int *excluded_atom_numbers,
float *d_reciprocal_ene, float *d_self_ene, float *d_direct_ene, float *d_correction_ene,
dim3 thread_PME, int PME_Nin, int PME_Nfft, int PME_Nall, const cufftHandle &PME_plan_r2c,
const cufftHandle &PME_plan_c2r, cudaStream_t stream);
#endif

View File

@ -28,7 +28,7 @@ __global__ void PME_BCFQ(cufftComplex *PME_FQ, float *PME_BC, int PME_Nfft) {
__global__ void PME_Final(int *PME_atom_near, const float *charge, const float *PME_Q, VECTOR *force,
const VECTOR *PME_frxyz, const UNSIGNED_INT_VECTOR *PME_kxyz,
const VECTOR PME_inverse_box_vector, const int atom_numbers) {
const _VECTOR PME_inverse_box_vector, const int atom_numbers) {
int atom = blockDim.x * blockIdx.x + threadIdx.x;
if (atom < atom_numbers) {
int k, kx;
@ -73,8 +73,9 @@ __global__ void PME_Final(int *PME_atom_near, const float *charge, const float *
void PMEReciprocalForce(int fftx, int ffty, int fftz, int atom_numbers, float beta, float *PME_BC, int *pme_uxyz,
float *pme_frxyz, float *PME_Q, float *pme_fq, int *PME_atom_near, int *pme_kxyz,
const float *box_length_f, const int *uint_crd_f, const float *charge, float *force,
cudaStream_t stream) {
const int *uint_crd_f, const float *charge, float *force, int PME_Nin, int PME_Nall,
int PME_Nfft, const cufftHandle &PME_plan_r2c, const cufftHandle &PME_plan_c2r,
const _VECTOR &PME_inverse_box_vector, cudaStream_t stream) {
Reset_List<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(3 * atom_numbers, force, 0.);
UNSIGNED_INT_VECTOR *uint_crd =
const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));
@ -86,98 +87,8 @@ void PMEReciprocalForce(int fftx, int ffty, int fftz, int atom_numbers, float be
VECTOR *PME_frxyz = reinterpret_cast<VECTOR *>(pme_frxyz);
VECTOR *frc = reinterpret_cast<VECTOR *>(force);
std::vector<float> h_box_length(3);
cudaMemcpyAsync(h_box_length.data(), box_length_f, sizeof(float) * h_box_length.size(), cudaMemcpyDeviceToHost,
stream);
cudaStreamSynchronize(stream);
VECTOR *box_length = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(h_box_length.data()));
cufftComplex *PME_FQ = reinterpret_cast<cufftComplex *>(pme_fq);
VECTOR PME_inverse_box_vector;
PME_inverse_box_vector.x = static_cast<float>(fftx) / box_length[0].x;
PME_inverse_box_vector.y = static_cast<float>(ffty) / box_length[0].y;
PME_inverse_box_vector.z = static_cast<float>(fftz) / box_length[0].z;
cufftHandle PME_plan_r2c;
cufftHandle PME_plan_c2r;
cufftPlan3d(&PME_plan_r2c, fftx, ffty, fftz, CUFFT_R2C);
cufftPlan3d(&PME_plan_c2r, fftx, ffty, fftz, CUFFT_C2R);
cufftSetStream(PME_plan_r2c, stream);
cufftSetStream(PME_plan_c2r, stream);
thread_PME.x = 8;
thread_PME.y = 8;
int PME_Nin = ffty * fftz;
int PME_Nfft = fftx * ffty * (fftz / 2 + 1);
int PME_Nall = fftx * ffty * fftz;
float volume = box_length[0].x * box_length[0].y * box_length[0].z;
UNSIGNED_INT_VECTOR *PME_kxyz_cpu;
Malloc_Safely(reinterpret_cast<void **>(&PME_kxyz_cpu), sizeof(UNSIGNED_INT_VECTOR) * 64);
int kx, ky, kz, kxrp, kyrp, kzrp, index;
for (kx = 0; kx < 4; kx++) {
for (ky = 0; ky < 4; ky++) {
for (kz = 0; kz < 4; kz++) {
index = kx * 16 + ky * 4 + kz;
PME_kxyz_cpu[index].uint_x = kx;
PME_kxyz_cpu[index].uint_y = ky;
PME_kxyz_cpu[index].uint_z = kz;
}
}
}
cudaMemcpyAsync(PME_kxyz, PME_kxyz_cpu, sizeof(UNSIGNED_INT_VECTOR) * 64, cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);
free(PME_kxyz_cpu);
// initial start
float *B1, *B2, *B3, *PME_BC0;
B1 = reinterpret_cast<float *>(malloc(sizeof(float) * fftx));
B2 = reinterpret_cast<float *>(malloc(sizeof(float) * ffty));
B3 = reinterpret_cast<float *>(malloc(sizeof(float) * fftz));
PME_BC0 = reinterpret_cast<float *>(malloc(sizeof(float) * PME_Nfft));
for (kx = 0; kx < fftx; kx++) {
B1[kx] = getb(kx, fftx, 4);
}
for (ky = 0; ky < ffty; ky++) {
B2[ky] = getb(ky, ffty, 4);
}
for (kz = 0; kz < fftz; kz++) {
B3[kz] = getb(kz, fftz, 4);
}
float mprefactor = PI * PI / -beta / beta;
float msq;
for (kx = 0; kx < fftx; kx++) {
kxrp = kx;
if (kx > fftx / 2) kxrp = fftx - kx;
for (ky = 0; ky < ffty; ky++) {
kyrp = ky;
if (ky > ffty / 2) kyrp = ffty - ky;
for (kz = 0; kz <= fftz / 2; kz++) {
kzrp = kz;
msq = kxrp * kxrp / box_length[0].x / box_length[0].x + kyrp * kyrp / box_length[0].y / box_length[0].y +
kzrp * kzrp / box_length[0].z / box_length[0].z;
index = kx * ffty * (fftz / 2 + 1) + ky * (fftz / 2 + 1) + kz;
if ((kx + ky + kz) == 0) {
PME_BC0[index] = 0;
} else {
PME_BC0[index] = 1.0 / PI / msq * exp(mprefactor * msq) / volume;
}
PME_BC0[index] *= B1[kx] * B2[ky] * B3[kz];
}
}
}
cudaMemcpyAsync(PME_BC, PME_BC0, sizeof(float) * PME_Nfft, cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);
free(B1);
free(B2);
free(B3);
free(PME_BC0);
// initial end
Reset_List<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(
3 * atom_numbers, reinterpret_cast<float *>(frc), 0.);
@ -198,8 +109,3 @@ void PMEReciprocalForce(int fftx, int ffty, int fftz, int atom_numbers, float be
PME_kxyz, PME_inverse_box_vector, atom_numbers);
return;
}
void PMEReciprocalForce(int fftx, int ffty, int fftz, int atom_numbers, float beta, float *PME_BC, int *pme_uxyz,
float *pme_frxyz, float *PME_Q, float *pme_fq, int *PME_atom_near, int *pme_kxyz,
const float *box_length_f, const int *uint_crd_f, const float *charge, float *force,
cudaStream_t stream);

View File

@ -16,13 +16,18 @@
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_PME_RECIPROCAL_FORCE_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_PME_RECIPROCAL_FORCE_IMPL_H_
#include <curand_kernel.h>
#include <vector>
#include <cufft.h>
#include "runtime/device/gpu/cuda_common.h"
// Plain aggregate of three floats (x, y, z). In this header it is the type of
// the PME_inverse_box_vector argument of PMEReciprocalForce.
struct _VECTOR {
  float x, y, z;
};
void PMEReciprocalForce(int fftx, int ffty, int fftz, int atom_numbers, float beta, float *PME_BC, int *pme_uxyz,
float *pme_frxyz, float *PME_Q, float *pme_fq, int *PME_atom_near, int *pme_kxyz,
const float *box_length_f, const int *uint_crd_f, const float *charge, float *force,
cudaStream_t stream);
const int *uint_crd_f, const float *charge, float *force, int PME_Nin, int PME_Nall,
int PME_Nfft, const cufftHandle &PME_plan_r2c, const cufftHandle &PME_plan_c2r,
const _VECTOR &PME_inverse_box_vector, cudaStream_t stream);
#endif

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/sponge/common/crd_to_uint_crd_kernel.h"
namespace mindspore {
namespace kernel {
// Registration entry for the CrdToUintCrd GPU kernel. The attribute list declares
// two float32 inputs and one uint32 output; the implementation class is
// CrdToUintCrdGpuKernel with template arguments <float, unsigned int>.
MS_REG_GPU_KERNEL_TWO(CrdToUintCrd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeUInt32),
                      CrdToUintCrdGpuKernel, float, unsigned int)
}  // namespace kernel
}  // namespace mindspore

View File

@ -0,0 +1,87 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_CRD_TO_UINT_CRD_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_CRD_TO_UINT_CRD_KERNEL_H_
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/cuda_common.h"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/crd_to_uint_crd_impl.cuh"
namespace mindspore {
namespace kernel {
// GPU kernel wrapper around CrdToUintCrd.
//   T   element type of the two inputs (input 0: scale factors, input 1: coordinates)
//   T1  element type of the output (integer coordinates)
// Init() reads the "atom_numbers" attribute and the inferred input shapes, then
// fills the size lists; Launch() forwards the device pointers to CrdToUintCrd.
template <typename T, typename T1>
class CrdToUintCrdGpuKernel : public GpuKernel {
 public:
  CrdToUintCrdGpuKernel() : ele_crd(1) {}
  ~CrdToUintCrdGpuKernel() override = default;

  // Caches the node, reads the atom count and computes the element count of each
  // input as the product of its inferred shape dimensions. Always returns true.
  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
    auto shape_crd_to_uint_crd_cof = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto shape_crd = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);

    for (size_t i = 0; i < shape_crd_to_uint_crd_cof.size(); i++) {
      ele_crd_to_uint_crd_cof *= shape_crd_to_uint_crd_cof[i];
    }
    for (size_t i = 0; i < shape_crd.size(); i++) {
      ele_crd *= shape_crd[i];
    }

    InitSizeLists();
    return true;
  }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  // Enqueues the conversion on the given stream; the workspace list is unused.
  // Always returns true.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto crd_to_uint_crd_cof = GetDeviceAddress<const T>(inputs, 0);
    auto crd = GetDeviceAddress<const T>(inputs, 1);
    auto uint_crd = GetDeviceAddress<T1>(outputs, 0);

    CrdToUintCrd(atom_numbers, crd_to_uint_crd_cof, crd, uint_crd, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

 protected:
  // Byte sizes of the two inputs and of the single output (three T1 values per atom).
  void InitSizeLists() override {
    input_size_list_.push_back(ele_crd_to_uint_crd_cof * sizeof(T));
    input_size_list_.push_back(ele_crd * sizeof(T));
    // The output holds T1 elements, so it is sized with sizeof(T1); the count is
    // widened to size_t before multiplying so 3 * atom_numbers cannot overflow int.
    output_size_list_.push_back(static_cast<size_t>(atom_numbers) * 3 * sizeof(T1));
  }

 private:
  size_t ele_crd_to_uint_crd_cof = 1;
  size_t ele_crd = 1;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
  // Set from the "atom_numbers" attribute in Init(); zero until then.
  int atom_numbers = 0;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_CRD_TO_UINT_CRD_KERNEL_H_

View File

@ -38,6 +38,7 @@ MS_REG_GPU_KERNEL_TWO(NeighborListUpdate,
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
NeighborListUpdateGpuKernel, int, float)

View File

@ -36,7 +36,6 @@ class NeighborListUpdateGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
grid_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "grid_numbers"));
atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
refresh_count = static_cast<int>(GetAttr<int64_t>(kernel_node, "refresh_count"));
refresh_interval = static_cast<int>(GetAttr<int64_t>(kernel_node, "refresh_interval"));
not_first_time = static_cast<int>(GetAttr<int64_t>(kernel_node, "not_first_time"));
Nxy = static_cast<int>(GetAttr<int64_t>(kernel_node, "Nxy"));
@ -47,7 +46,8 @@ class NeighborListUpdateGpuKernel : public GpuKernel {
cutoff_with_skin = static_cast<float>(GetAttr<float>(kernel_node, "cutoff_with_skin"));
half_cutoff_with_skin = static_cast<float>(GetAttr<float>(kernel_node, "half_cutoff_with_skin"));
cutoff_with_skin_square = static_cast<float>(GetAttr<float>(kernel_node, "cutoff_with_skin_square"));
h_bucket.resize(grid_numbers);
h_gpointer.resize(grid_numbers);
InitSizeLists();
return true;
}
@ -76,17 +76,18 @@ class NeighborListUpdateGpuKernel : public GpuKernel {
auto excluded_list = GetDeviceAddress<int>(inputs, 15);
auto excluded_numbers = GetDeviceAddress<int>(inputs, 16);
auto need_refresh_flag = GetDeviceAddress<int>(inputs, 17);
auto d_refresh_count = GetDeviceAddress<int>(inputs, 18);
GRID_BUCKET *d_bucket = reinterpret_cast<GRID_BUCKET *>(GetDeviceAddress<int>(workspaces, 0));
GRID_POINTER *d_gpointer = reinterpret_cast<GRID_POINTER *>(GetDeviceAddress<int>(workspaces, 1));
NEIGHBOR_LIST *nl = GetDeviceAddress<NEIGHBOR_LIST>(workspaces, 2);
float *half_crd_to_uint_crd_cof = GetDeviceAddress<float>(workspaces, 3);
std::vector<GRID_BUCKET> h_bucket(grid_numbers);
// std::vector<GRID_BUCKET> h_bucket(grid_numbers);
for (size_t i = 0; i < h_bucket.size(); i += 1) {
h_bucket[i].atom_serial = bucket + i * max_atom_in_grid_numbers;
}
std::vector<GRID_POINTER> h_gpointer(grid_numbers);
// std::vector<GRID_POINTER> h_gpointer(grid_numbers);
for (size_t i = 0; i < h_gpointer.size(); i += 1) {
h_gpointer[i].grid_serial = gpointer + i * 125;
}
@ -98,7 +99,7 @@ class NeighborListUpdateGpuKernel : public GpuKernel {
Construct_Neighbor_List(atom_numbers, max_neighbor_numbers, nl_atom_numbers, nl_atom_serial, nl,
reinterpret_cast<cudaStream_t>(stream_ptr));
Neighbor_List_Update(grid_numbers, atom_numbers, refresh_count, refresh_interval, not_first_time, skin, Nxy,
Neighbor_List_Update(grid_numbers, atom_numbers, d_refresh_count, refresh_interval, not_first_time, skin, Nxy,
cutoff_square, cutoff_with_skin_square, grid_N, box_length, atom_numbers_in_grid_bucket,
grid_length_inverse, atom_in_grid_serial, d_bucket, crd, old_crd, crd_to_uint_crd_cof,
half_crd_to_uint_crd_cof, uint_crd, uint_dr_to_dr_cof, d_gpointer, nl, excluded_list_start,
@ -132,6 +133,7 @@ class NeighborListUpdateGpuKernel : public GpuKernel {
input_size_list_.push_back(sizeof(int) * excluded_atom_numbers);
input_size_list_.push_back(sizeof(int) * atom_numbers);
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
workspace_size_list_.push_back(sizeof(GRID_BUCKET) * grid_numbers);
@ -148,7 +150,6 @@ class NeighborListUpdateGpuKernel : public GpuKernel {
int not_first_time;
int atom_numbers;
int grid_numbers;
int refresh_count;
int refresh_interval;
int Nxy;
int max_atom_in_grid_numbers;
@ -163,6 +164,8 @@ class NeighborListUpdateGpuKernel : public GpuKernel {
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
std::vector<GRID_BUCKET> h_bucket;
std::vector<GRID_POINTER> h_gpointer;
};
} // namespace kernel
} // namespace mindspore

View File

@ -45,14 +45,6 @@ class MDIterationLeapFrogGpuKernel : public GpuKernel {
is_max_velocity = static_cast<int>(GetAttr<int64_t>(kernel_node, "is_max_velocity"));
max_velocity = static_cast<float>(GetAttr<float>(kernel_node, "max_velocity"));
// printf("float4_numbers: %d", float4_numbers);
// printf("atom_numbers: %d", atom_numbers);
// printf("half_dt: %f", half_dt);
// printf("dt: %f", dt);
// printf("exp_gamma: %f", exp_gamma);
// printf("is_max_velocity: %d", is_max_velocity);
// printf("max_velocity: %f", max_velocity);
auto shape_mass_inverse = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto shape_qrt_mass = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);

View File

@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/sponge/nvtit/md_iteration_leap_frog_liujian_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// Registration entry for the MDIterationLeapFrogLiujian op: eight float32 inputs and one float32 output,
// bound to MDIterationLeapFrogLiujianCudaGpuKernel with float and int as the two type arguments.
// The per-input names below follow the order in which Launch() of that kernel class reads its inputs.
// NOTE(review): MS_REG_GPU_KERNEL_TWO is defined outside this file; presumably it registers the kernel
// with the GPU kernel factory - confirm against gpu_kernel_factory.h.
MS_REG_GPU_KERNEL_TWO(MDIterationLeapFrogLiujian,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)  // input 0: inverse_mass
.AddInputAttr(kNumberTypeFloat32)  // input 1: sqrt_mass_inverse
.AddInputAttr(kNumberTypeFloat32)  // input 2: vel
.AddInputAttr(kNumberTypeFloat32)  // input 3: crd
.AddInputAttr(kNumberTypeFloat32)  // input 4: frc
.AddInputAttr(kNumberTypeFloat32)  // input 5: acc
.AddInputAttr(kNumberTypeFloat32)  // input 6: rand_state (declared float32, reinterpreted as curand state in Launch)
.AddInputAttr(kNumberTypeFloat32)  // input 7: rand_frc
.AddOutputAttr(kNumberTypeFloat32),
MDIterationLeapFrogLiujianCudaGpuKernel, float, int)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,100 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MD_ITERATION_LEAP_FROG_LIUJIAN_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MD_ITERATION_LEAP_FROG_LIUJIAN_GPU_KERNEL_H_
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/nvtit/md_iteration_leap_frog_liujian_gpu_impl.cuh"
#include <cuda_runtime_api.h>
#include <map>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/cuda_common.h"
namespace mindspore {
namespace kernel {
template <typename T, typename T1>
class MDIterationLeapFrogLiujianCudaGpuKernel : public GpuKernel {
public:
MDIterationLeapFrogLiujianCudaGpuKernel() {}
~MDIterationLeapFrogLiujianCudaGpuKernel() override = default;
bool Init(const CNodePtr &kernel_node) override {
// get bond_numbers
kernel_node_ = kernel_node;
atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
half_dt = static_cast<float>(GetAttr<float>(kernel_node, "half_dt"));
dt = static_cast<float>(GetAttr<float>(kernel_node, "dt"));
exp_gamma = static_cast<float>(GetAttr<float>(kernel_node, "exp_gamma"));
float4_numbers = ceil(3. * static_cast<double>(atom_numbers) / 4.);
InitSizeLists();
return true;
}
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto inverse_mass = GetDeviceAddress<float>(inputs, 0);
auto sqrt_mass_inverse = GetDeviceAddress<float>(inputs, 1);
auto vel = GetDeviceAddress<float>(inputs, 2);
auto crd = GetDeviceAddress<float>(inputs, 3);
auto frc = GetDeviceAddress<float>(inputs, 4);
auto acc = GetDeviceAddress<float>(inputs, 5);
auto rand_state = GetDeviceAddress<float>(inputs, 6);
auto rand_frc = GetDeviceAddress<float>(inputs, 7);
auto output = GetDeviceAddress<float>(outputs, 0);
MD_Iteration_Leap_Frog_With_LiuJian(atom_numbers, half_dt, dt, exp_gamma, float4_numbers, inverse_mass,
sqrt_mass_inverse, vel, crd, frc, acc,
reinterpret_cast<curandStatePhilox4_32_10_t *>(rand_state), rand_frc, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(atom_numbers * sizeof(float));
input_size_list_.push_back(atom_numbers * sizeof(float));
input_size_list_.push_back(atom_numbers * 3 * sizeof(float));
input_size_list_.push_back(atom_numbers * 3 * sizeof(float));
input_size_list_.push_back(atom_numbers * 3 * sizeof(float));
input_size_list_.push_back(atom_numbers * 3 * sizeof(float));
input_size_list_.push_back(float4_numbers * sizeof(curandStatePhilox4_32_10_t));
input_size_list_.push_back(atom_numbers * 3 * sizeof(float));
output_size_list_.push_back(atom_numbers * 3 * sizeof(T));
}
private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int atom_numbers;
float half_dt;
float dt;
float exp_gamma;
int float4_numbers;
};
} // namespace kernel
} // namespace mindspore
#endif

View File

@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/sponge/nvtit/md_iteration_setup_random_state.h"
namespace mindspore {
namespace kernel {
// Registration entry for the MDIterationSetupRandState op: no inputs and a single output declared as float32,
// bound to MDIterationSetupRandStateGpuKernel with float and int as the two type arguments.
// The kernel's Launch() reinterprets that output buffer as an array of curandStatePhilox4_32_10_t.
// NOTE(review): MS_REG_GPU_KERNEL_TWO is defined outside this file; presumably it registers the kernel
// with the GPU kernel factory - confirm against gpu_kernel_factory.h.
MS_REG_GPU_KERNEL_TWO(MDIterationSetupRandState, KernelAttr().AddOutputAttr(kNumberTypeFloat32),
MDIterationSetupRandStateGpuKernel, float, int)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,74 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MD_ITERATION_SETUP_RANDOM_STATE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MD_ITERATION_SETUP_RANDOM_STATE_GPU_KERNEL_H_
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/nvtit/md_iteration_setup_random_state_gpu_impl.cuh"
#include <cuda_runtime_api.h>
#include <map>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/cuda_common.h"
namespace mindspore {
namespace kernel {
template <typename T, typename T1>
class MDIterationSetupRandStateGpuKernel : public GpuKernel {
public:
MDIterationSetupRandStateGpuKernel() {}
~MDIterationSetupRandStateGpuKernel() override = default;
bool Init(const CNodePtr &kernel_node) override {
// get bond_numbers
kernel_node_ = kernel_node;
atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
seed = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed"));
float4_numbers = ceil(3. * static_cast<double>(atom_numbers) / 4.);
InitSizeLists();
return true;
}
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto output = GetDeviceAddress<float>(outputs, 0);
curandStatePhilox4_32_10_t *rand_state = reinterpret_cast<curandStatePhilox4_32_10_t *>(output);
MD_Iteration_Setup_Random_State(float4_numbers, rand_state, seed, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected:
void InitSizeLists() override { output_size_list_.push_back(sizeof(curandStatePhilox4_32_10_t) * float4_numbers); }
private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int atom_numbers;
int seed;
int float4_numbers;
};
} // namespace kernel
} // namespace mindspore
#endif

View File

@ -19,7 +19,6 @@ namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(PMEEnergy,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)

View File

@ -18,8 +18,6 @@
#include <cuda_runtime_api.h>
#include <cufft.h>
#include <vector>
#include <string>
#include <map>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/cuda_common.h"
@ -40,8 +38,76 @@ class PMEEnergyGpuKernel : public GpuKernel {
fftx = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftx"));
ffty = static_cast<int>(GetAttr<int64_t>(kernel_node, "ffty"));
fftz = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftz"));
PME_Nall = fftx * ffty * fftz;
float box_length_0 = static_cast<float>(GetAttr<float_t>(kernel_node, "box_length_0"));
float box_length_1 = static_cast<float>(GetAttr<float_t>(kernel_node, "box_length_1"));
float box_length_2 = static_cast<float>(GetAttr<float_t>(kernel_node, "box_length_2"));
std::vector<float> h_box_length(3);
h_box_length[0] = box_length_0;
h_box_length[1] = box_length_1;
h_box_length[2] = box_length_2;
VECTOR *box_length = reinterpret_cast<VECTOR *>(h_box_length.data());
cufftPlan3d(&PME_plan_r2c, fftx, ffty, fftz, CUFFT_R2C);
cufftPlan3d(&PME_plan_c2r, fftx, ffty, fftz, CUFFT_C2R);
_thread_PME.x = 8;
_thread_PME.y = 8;
PME_Nin = ffty * fftz;
PME_Nfft = fftx * ffty * (fftz / 2 + 1);
PME_Nall = fftx * ffty * fftz;
PME_kxyz_cpu.resize(64);
volume = box_length[0].x * box_length[0].y * box_length[0].z;
int kx, ky, kz, kxrp, kyrp, kzrp, index;
for (kx = 0; kx < 4; kx++) {
for (ky = 0; ky < 4; ky++) {
for (kz = 0; kz < 4; kz++) {
index = kx * 16 + ky * 4 + kz;
PME_kxyz_cpu[index].uint_x = kx;
PME_kxyz_cpu[index].uint_y = ky;
PME_kxyz_cpu[index].uint_z = kz;
}
}
}
B1.resize(fftx);
B2.resize(ffty);
B3.resize(fftz);
PME_BC0.resize(PME_Nfft);
for (kx = 0; kx < fftx; kx++) {
B1[kx] = getb(kx, fftx, 4);
}
for (ky = 0; ky < ffty; ky++) {
B2[ky] = getb(ky, ffty, 4);
}
for (kz = 0; kz < fftz; kz++) {
B3[kz] = getb(kz, fftz, 4);
}
float mprefactor = PI * PI / -beta / beta;
float msq;
for (kx = 0; kx < fftx; kx++) {
kxrp = kx;
if (kx > fftx / 2) kxrp = fftx - kx;
for (ky = 0; ky < ffty; ky++) {
kyrp = ky;
if (ky > ffty / 2) kyrp = ffty - ky;
for (kz = 0; kz <= fftz / 2; kz++) {
kzrp = kz;
msq = kxrp * kxrp / box_length[0].x / box_length[0].x + kyrp * kyrp / box_length[0].y / box_length[0].y +
kzrp * kzrp / box_length[0].z / box_length[0].z;
index = kx * ffty * (fftz / 2 + 1) + ky * (fftz / 2 + 1) + kz;
if ((kx + ky + kz) == 0) {
PME_BC0[index] = 0;
} else {
PME_BC0[index] = 1.0 / PI / msq * exp(mprefactor * msq) / volume;
}
PME_BC0[index] *= B1[kx] * B2[ky] * B3[kz];
}
}
}
InitSizeLists();
return true;
@ -53,15 +119,14 @@ class PMEEnergyGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto boxlength = GetDeviceAddress<T>(inputs, 0);
auto uint_crd = GetDeviceAddress<T1>(inputs, 1);
auto charge = GetDeviceAddress<T>(inputs, 2);
auto nl_numbers = GetDeviceAddress<T1>(inputs, 3);
auto nl_serial = GetDeviceAddress<T1>(inputs, 4);
auto scaler = GetDeviceAddress<T>(inputs, 5);
auto excluded_list_start = GetDeviceAddress<int>(inputs, 6);
auto excluded_list = GetDeviceAddress<int>(inputs, 7);
auto excluded_atom_numbers = GetDeviceAddress<int>(inputs, 8);
auto uint_crd = GetDeviceAddress<T1>(inputs, 0);
auto charge = GetDeviceAddress<T>(inputs, 1);
auto nl_numbers = GetDeviceAddress<T1>(inputs, 2);
auto nl_serial = GetDeviceAddress<T1>(inputs, 3);
auto scaler = GetDeviceAddress<T>(inputs, 4);
auto excluded_list_start = GetDeviceAddress<int>(inputs, 5);
auto excluded_list = GetDeviceAddress<int>(inputs, 6);
auto excluded_atom_numbers = GetDeviceAddress<int>(inputs, 7);
auto pme_uxyz = GetDeviceAddress<int>(workspace, 0); // workspace
auto pme_frxyz = GetDeviceAddress<float>(workspace, 1); // workspace
@ -77,16 +142,22 @@ class PMEEnergyGpuKernel : public GpuKernel {
auto direct_ene = GetDeviceAddress<T>(outputs, 2);
auto correction_ene = GetDeviceAddress<T>(outputs, 3);
PMEEnergy(fftx, ffty, fftz, atom_numbers, beta, boxlength, pme_bc, pme_uxyz, pme_frxyz, pme_q, pme_fq,
pme_atom_near, pme_kxyz, uint_crd, charge, nl_numbers, nl_serial, nl, scaler, excluded_list_start,
excluded_list, excluded_atom_numbers, reciprocal_ene, self_ene, direct_ene, correction_ene,
reinterpret_cast<cudaStream_t>(stream_ptr));
cufftSetStream(PME_plan_r2c, reinterpret_cast<cudaStream_t>(stream_ptr));
cufftSetStream(PME_plan_c2r, reinterpret_cast<cudaStream_t>(stream_ptr));
cudaMemcpyAsync(pme_kxyz, PME_kxyz_cpu.data(), sizeof(UNSIGNED_INT_VECTOR) * 64, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
cudaMemcpyAsync(pme_bc, PME_BC0.data(), sizeof(float) * PME_Nfft, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
PMEEnergy(fftx, ffty, fftz, atom_numbers, beta, pme_bc, pme_uxyz, pme_frxyz, pme_q, pme_fq, pme_atom_near, pme_kxyz,
uint_crd, charge, nl_numbers, nl_serial, nl, scaler, excluded_list_start, excluded_list,
excluded_atom_numbers, reciprocal_ene, self_ene, direct_ene, correction_ene, _thread_PME, PME_Nin,
PME_Nfft, PME_Nall, PME_plan_r2c, PME_plan_c2r, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(sizeof(VECTOR));
input_size_list_.push_back(atom_numbers * sizeof(UNSIGNED_INT_VECTOR));
input_size_list_.push_back(atom_numbers * sizeof(VECTOR));
input_size_list_.push_back(atom_numbers * sizeof(T1));
@ -112,12 +183,56 @@ class PMEEnergyGpuKernel : public GpuKernel {
output_size_list_.push_back(sizeof(T));
}
// Complex exponential: returns exp(z.x) * (cos(z.y) + i * sin(z.y)) with the real part in .x
// and the imaginary part in .y.
cufftComplex expc(cufftComplex z) {
  const float magnitude = expf(z.x);
  float sin_part;
  float cos_part;
  sincosf(z.y, &sin_part, &cos_part);

  cufftComplex out;
  out.x = cos_part * magnitude;
  out.y = sin_part * magnitude;
  return out;
}
// Recursive weight function M_n(u); the recursion matches the usual cardinal B-spline definition
// (NOTE(review): inferred from the formula and from getb's B_order parameter - confirm).
//   n == 2 : 1 - |u - 1| for 0 <= u <= 2, otherwise 0
//   n  > 2 : u / (n - 1) * M_(u, n - 1) + (n - u) / (n - 1) * M_(u - 1, n - 1)
// n must be >= 2; a smaller n never reaches the base case.
float M_(float u, int n) {
  if (n == 2) {
    if (u > 2 || u < 0) return 0;
    // fabsf instead of unqualified abs: depending on which headers are in scope, abs() may bind to the
    // C int overload and silently truncate a fractional argument.
    return 1 - fabsf(u - 1);
  }
  return u / (n - 1) * M_(u, n - 1) + (n - u) / (n - 1) * M_(u - 1, n - 1);
}
float getb(int k, int NFFT, int B_order) {
cufftComplex tempc, tempc2, res;
float tempf;
tempc2.x = 0;
tempc2.y = 0;
tempc.x = 0;
tempc.y = 2 * (B_order - 1) * PI * k / NFFT;
res = expc(tempc);
for (int kk = 0; kk < (B_order - 1); kk++) {
tempc.x = 0;
tempc.y = 2 * PI * k / NFFT * kk;
tempc = expc(tempc);
tempf = M_(kk + 1, B_order);
tempc2.x += tempf * tempc.x;
tempc2.y += tempf * tempc.y;
}
res = cuCdivf(res, tempc2);
return res.x * res.x + res.y * res.y;
}
private:
size_t ele_uint_crd = 1;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
std::vector<float> B1;
std::vector<float> B2;
std::vector<float> B3;
std::vector<float> PME_BC0;
int atom_numbers;
int excluded_numbers;
int max_nl_numbers = 800;
@ -125,8 +240,16 @@ class PMEEnergyGpuKernel : public GpuKernel {
int ffty;
int fftz;
float beta;
int PME_Nin;
int PME_Nall;
int PME_Nfft;
float volume;
float PI = 3.1415926;
cufftHandle PME_plan_r2c;
cufftHandle PME_plan_c2r;
dim3 _thread_PME;
struct VECTOR {
float x;
float y;
@ -138,7 +261,7 @@ class PMEEnergyGpuKernel : public GpuKernel {
unsigned int uint_y;
unsigned int uint_z;
};
std::vector<UNSIGNED_INT_VECTOR> PME_kxyz_cpu;
struct NEIGHBOR_LIST {
int atom_numbers;
int *atom_serial;

View File

@ -17,13 +17,10 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(PMEReciprocalForce,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
PMEReciprocalForceGpuKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
PMEReciprocalForce,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PMEReciprocalForceGpuKernel, float, int)
} // namespace kernel
} // namespace mindspore

View File

@ -41,6 +41,75 @@ class PMEReciprocalForceGpuKernel : public GpuKernel {
fftz = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftz"));
PME_Nall = fftx * ffty * fftz;
PME_Nfft = fftx * ffty * (fftz / 2 + 1);
PME_Nin = ffty * fftz;
float box_length_0 = static_cast<float>(GetAttr<float_t>(kernel_node, "box_length_0"));
float box_length_1 = static_cast<float>(GetAttr<float_t>(kernel_node, "box_length_1"));
float box_length_2 = static_cast<float>(GetAttr<float_t>(kernel_node, "box_length_2"));
std::vector<float> h_box_length(3);
h_box_length[0] = box_length_0;
h_box_length[1] = box_length_1;
h_box_length[2] = box_length_2;
VECTOR *box_length = reinterpret_cast<VECTOR *>(h_box_length.data());
PME_inverse_box_vector.x = static_cast<float>(fftx) / box_length[0].x;
PME_inverse_box_vector.y = static_cast<float>(ffty) / box_length[0].y;
PME_inverse_box_vector.z = static_cast<float>(fftz) / box_length[0].z;
cufftPlan3d(&PME_plan_r2c, fftx, ffty, fftz, CUFFT_R2C);
cufftPlan3d(&PME_plan_c2r, fftx, ffty, fftz, CUFFT_C2R);
float volume = box_length[0].x * box_length[0].y * box_length[0].z;
PME_kxyz_cpu.resize(64);
int kx, ky, kz, kxrp, kyrp, kzrp, index;
for (kx = 0; kx < 4; kx++) {
for (ky = 0; ky < 4; ky++) {
for (kz = 0; kz < 4; kz++) {
index = kx * 16 + ky * 4 + kz;
PME_kxyz_cpu[index].uint_x = kx;
PME_kxyz_cpu[index].uint_y = ky;
PME_kxyz_cpu[index].uint_z = kz;
}
}
}
B1.resize(fftx);
B2.resize(ffty);
B3.resize(fftz);
PME_BC0.resize(PME_Nfft);
for (kx = 0; kx < fftx; kx++) {
B1[kx] = getb(kx, fftx, 4);
}
for (ky = 0; ky < ffty; ky++) {
B2[ky] = getb(ky, ffty, 4);
}
for (kz = 0; kz < fftz; kz++) {
B3[kz] = getb(kz, fftz, 4);
}
float mprefactor = PI * PI / -beta / beta;
float msq;
for (kx = 0; kx < fftx; kx++) {
kxrp = kx;
if (kx > fftx / 2) kxrp = fftx - kx;
for (ky = 0; ky < ffty; ky++) {
kyrp = ky;
if (ky > ffty / 2) kyrp = ffty - ky;
for (kz = 0; kz <= fftz / 2; kz++) {
kzrp = kz;
msq = kxrp * kxrp / box_length[0].x / box_length[0].x + kyrp * kyrp / box_length[0].y / box_length[0].y +
kzrp * kzrp / box_length[0].z / box_length[0].z;
index = kx * ffty * (fftz / 2 + 1) + ky * (fftz / 2 + 1) + kz;
if ((kx + ky + kz) == 0) {
PME_BC0[index] = 0;
} else {
PME_BC0[index] = 1.0 / PI / msq * exp(mprefactor * msq) / volume;
}
PME_BC0[index] *= B1[kx] * B2[ky] * B3[kz];
}
}
}
InitSizeLists();
return true;
@ -52,9 +121,8 @@ class PMEReciprocalForceGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto boxlength = GetDeviceAddress<T>(inputs, 0);
auto uint_crd = GetDeviceAddress<const T1>(inputs, 1);
auto charge = GetDeviceAddress<T>(inputs, 2);
auto uint_crd = GetDeviceAddress<const T1>(inputs, 0);
auto charge = GetDeviceAddress<T>(inputs, 1);
auto pme_uxyz = GetDeviceAddress<int>(workspace, 0); // workspace
auto pme_frxyz = GetDeviceAddress<float>(workspace, 1); // workspace
@ -65,9 +133,15 @@ class PMEReciprocalForceGpuKernel : public GpuKernel {
auto pme_kxyz = GetDeviceAddress<int>(workspace, 6); // workspace
auto force = GetDeviceAddress<T>(outputs, 0);
cufftSetStream(PME_plan_r2c, reinterpret_cast<cudaStream_t>(stream_ptr));
cufftSetStream(PME_plan_c2r, reinterpret_cast<cudaStream_t>(stream_ptr));
cudaMemcpyAsync(pme_kxyz, PME_kxyz_cpu.data(), sizeof(UNSIGNED_INT_VECTOR) * 64, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
cudaMemcpyAsync(pme_bc, PME_BC0.data(), sizeof(float) * PME_Nfft, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
PMEReciprocalForce(fftx, ffty, fftz, atom_numbers, beta, pme_bc, pme_uxyz, pme_frxyz, pme_q, pme_fq, pme_atom_near,
pme_kxyz, boxlength, uint_crd, charge, force, reinterpret_cast<cudaStream_t>(stream_ptr));
pme_kxyz, uint_crd, charge, force, PME_Nin, PME_Nall, PME_Nfft, PME_plan_r2c, PME_plan_c2r,
PME_inverse_box_vector, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -88,6 +162,44 @@ class PMEReciprocalForceGpuKernel : public GpuKernel {
output_size_list_.push_back(atom_numbers * sizeof(VECTOR));
}
// Complex exponential: returns exp(z.x) * (cos(z.y) + i * sin(z.y)) with the real part in .x
// and the imaginary part in .y.
cufftComplex expc(cufftComplex z) {
  const float magnitude = expf(z.x);
  float sin_part;
  float cos_part;
  sincosf(z.y, &sin_part, &cos_part);

  cufftComplex out;
  out.x = cos_part * magnitude;
  out.y = sin_part * magnitude;
  return out;
}
// Recursive weight function M_n(u); the recursion matches the usual cardinal B-spline definition
// (NOTE(review): inferred from the formula and from getb's B_order parameter - confirm).
//   n == 2 : 1 - |u - 1| for 0 <= u <= 2, otherwise 0
//   n  > 2 : u / (n - 1) * M_(u, n - 1) + (n - u) / (n - 1) * M_(u - 1, n - 1)
// n must be >= 2; a smaller n never reaches the base case.
float M_(float u, int n) {
  if (n == 2) {
    if (u > 2 || u < 0) return 0;
    // fabsf instead of unqualified abs: depending on which headers are in scope, abs() may bind to the
    // C int overload and silently truncate a fractional argument.
    return 1 - fabsf(u - 1);
  }
  return u / (n - 1) * M_(u, n - 1) + (n - u) / (n - 1) * M_(u - 1, n - 1);
}
float getb(int k, int NFFT, int B_order) {
cufftComplex tempc, tempc2, res;
float tempf;
tempc2.x = 0;
tempc2.y = 0;
tempc.x = 0;
tempc.y = 2 * (B_order - 1) * PI * k / NFFT;
res = expc(tempc);
for (int kk = 0; kk < (B_order - 1); kk++) {
tempc.x = 0;
tempc.y = 2 * PI * k / NFFT * kk;
tempc = expc(tempc);
tempf = M_(kk + 1, B_order);
tempc2.x += tempf * tempc.x;
tempc2.y += tempf * tempc.y;
}
res = cuCdivf(res, tempc2);
return res.x * res.x + res.y * res.y;
}
private:
size_t ele_uint_crd = 1;
@ -101,18 +213,27 @@ class PMEReciprocalForceGpuKernel : public GpuKernel {
float beta;
int PME_Nall;
int PME_Nfft;
int PME_Nin;
float PI = 3.1415926;
std::vector<float> B1;
std::vector<float> B2;
std::vector<float> B3;
std::vector<float> PME_BC0;
cufftHandle PME_plan_r2c;
cufftHandle PME_plan_c2r;
struct VECTOR {
float x;
float y;
float z;
};
_VECTOR PME_inverse_box_vector;
struct UNSIGNED_INT_VECTOR {
unsigned int uint_x;
unsigned int uint_y;
unsigned int uint_z;
};
std::vector<UNSIGNED_INT_VECTOR> PME_kxyz_cpu;
};
} // namespace kernel
} // namespace mindspore

View File

@ -105,7 +105,8 @@ from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAto
LJForce, LJEnergy, LJForceWithPMEDirectForce, PMEExcludedForce, PMEEnergy, Dihedral14LJForce,
Dihedral14LJForceWithDirectCF, Dihedral14LJEnergy, Dihedral14LJCFForceWithAtomEnergy,
Dihedral14LJAtomEnergy, Dihedral14CFEnergy, Dihedral14CFAtomEnergy, MDIterationLeapFrog,
GetCenterOfGeometry, MDTemperature, NeighborListUpdate)
GetCenterOfGeometry, MDTemperature, NeighborListUpdate, MDIterationLeapFrogLiujian,
CrdToUintCrd, MDIterationSetupRandState)
__all__ = [
@ -465,7 +466,9 @@ __all__ = [
"GetCenterOfGeometry",
"MDTemperature",
"NeighborListUpdate",
"MDIterationLeapFrogLiujian",
"CrdToUintCrd",
"MDIterationSetupRandState",
]
__all__.sort()

File diff suppressed because it is too large Load Diff

View File

@ -12,44 +12,78 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""main"""
import time
'''main'''
import argparse
from mindspore import context
from src.simulation_initial import Simulation
import time
from src.simulation import Simulation
import mindspore.context as context
from mindspore import Tensor
parser = argparse.ArgumentParser(description='Sponge Controller')
parser.add_argument('--i', type=str, default=None, help='input file')
parser.add_argument('--amber_parm', type=str, default=None,
help='paramter file in AMBER type')
parser.add_argument('--c', type=str, default=None,
help='initial coordinates file')
parser.add_argument('--amber_parm', type=str, default=None, help='paramter file in AMBER type')
parser.add_argument('--c', type=str, default=None, help='initial coordinates file')
parser.add_argument('--r', type=str, default="restrt", help='')
parser.add_argument('--x', type=str, default="mdcrd", help='')
parser.add_argument('--o', type=str, default="mdout", help="")
parser.add_argument('--box', type=str, default="mdbox", help='')
parser.add_argument('--device_id', type=int, default=0, help='')
args_opt = parser.parse_args()
context.set_context(mode=context.PYNATIVE_MODE,
device_target="GPU", device_id=0, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False)
if __name__ == "__main__":
start = time.time()
simulation = Simulation(args_opt)
simulation.Main_Initial()
res = simulation.Initial_Neighbor_List_Update(not_first_time=0)
md_info = simulation.md_info
md_info.step_limit = 1
for i in range(1, md_info.step_limit + 1):
print("steps: ", i)
md_info.steps = i
simulation.Main_Before_Calculate_Force()
simulation.Main_Calculate_Force()
simulation.Main_Calculate_Energy()
simulation.Main_After_Calculate_Energy()
temperature = simulation.Main_Print()
simulation.Main_Iteration_2()
start = time.time()
compiler_time = 0
save_path = args_opt.o
file = open(save_path, 'w')
for steps in range(simulation.md_info.step_limit):
print_step = steps % simulation.ntwx
if steps == simulation.md_info.step_limit - 1:
print_step = 0
temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene, _ = simulation(Tensor(steps), Tensor(print_step))
if steps == 0:
compiler_time = time.time()
if steps % simulation.ntwx == 0 or steps == simulation.md_info.step_limit - 1:
if steps == 0:
print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_")
file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n")
temperature = temperature.asnumpy()
total_potential_energy = total_potential_energy.asnumpy()
print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)),
end=" ")
if simulation.bond.bond_numbers > 0:
sigma_of_bond_ene = sigma_of_bond_ene.asnumpy()
print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ")
if simulation.angle.angle_numbers > 0:
sigma_of_angle_ene = sigma_of_angle_ene.asnumpy()
print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ")
if simulation.dihedral.dihedral_numbers > 0:
sigma_of_dihedral_ene = sigma_of_dihedral_ene.asnumpy()
print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ")
if simulation.nb14.nb14_numbers > 0:
nb14_lj_energy_sum = nb14_lj_energy_sum.asnumpy()
nb14_cf_energy_sum = nb14_cf_energy_sum.asnumpy()
print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ")
LJ_energy_sum = LJ_energy_sum.asnumpy()
ee_ene = ee_ene.asnumpy()
print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ")
print("{:>12.3f}".format(float(ee_ene)))
if file is not None:
file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}"
" {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy),
float(sigma_of_bond_ene), float(sigma_of_angle_ene),
float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum),
float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene)))
end = time.time()
file.close()
print("Main time(s):", end - start)
simulation.Main_Destroy()
print("Main time(s) without compiler:", end - compiler_time)

View File

@ -12,31 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""angle class"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class Angle(nn.Cell):
"""Angle class"""
'''Angle'''
class Angle:
'''Angle'''
def __init__(self, controller):
super(Angle, self).__init__()
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
self.atom_c = Tensor(np.asarray(self.h_atom_c, np.int32), mstype.int32)
self.angle_k = Tensor(np.asarray(self.h_angle_k, np.float32), mstype.float32)
self.angle_theta0 = Tensor(np.asarray(self.h_angle_theta0, np.float32), mstype.float32)
def read_process1(self, context):
"""read_information_from_amberfile process1"""
def read_information_from_amberfile(self, file_path):
'''read amber file'''
file = open(file_path, 'r')
context = file.readlines()
file.close()
for idx, val in enumerate(context):
if idx < len(context) - 1:
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
@ -46,6 +34,7 @@ class Angle(nn.Cell):
self.angle_with_H_numbers = value[4]
self.angle_without_H_numbers = value[5]
self.angle_numbers = self.angle_with_H_numbers + self.angle_without_H_numbers
# print(self.angle_numbers)
information = []
information.extend(value)
while count < 15:
@ -57,8 +46,10 @@ class Angle(nn.Cell):
print("angle type numbers ", self.angle_type_numbers)
break
def read_process2(self, context):
"""read_information_from_amberfile process2"""
self.h_atom_a = [0] * self.angle_numbers
self.h_atom_b = [0] * self.angle_numbers
self.h_atom_c = [0] * self.angle_numbers
self.h_type = [0] * self.angle_numbers
angle_count = 0
for idx, val in enumerate(context):
if "%FLAG ANGLES_INC_HYDROGEN" in val:
@ -81,20 +72,6 @@ class Angle(nn.Cell):
angle_count += 1
break
return angle_count
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.read_process1(context)
self.h_atom_a = [0] * self.angle_numbers
self.h_atom_b = [0] * self.angle_numbers
self.h_atom_c = [0] * self.angle_numbers
self.h_type = [0] * self.angle_numbers
angle_count = self.read_process2(context)
for idx, val in enumerate(context):
if "%FLAG ANGLES_WITHOUT_HYDROGEN" in val:
@ -109,14 +86,17 @@ class Angle(nn.Cell):
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.angle_without_H_numbers):
for _ in range(self.angle_without_H_numbers):
self.h_atom_a[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 0] / 3
self.h_atom_b[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 1] / 3
self.h_atom_c[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 2] / 3
self.h_type[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 3] - 1
angle_count += 1
break
self.processor(context, angle_count)
def processor(self, context, angle_count):
''' processor '''
self.type_k = [0] * self.angle_type_numbers
for idx, val in enumerate(context):
if "%FLAG ANGLE_FORCE_CONSTANT" in val:
@ -159,17 +139,3 @@ class Angle(nn.Cell):
for i in range(self.angle_numbers):
self.h_angle_k.append(self.type_k[self.h_type[i]])
self.h_angle_theta0.append(self.type_theta0[self.h_type[i]])
def Angle_Energy(self, uint_crd, uint_dr_to_dr_cof):
    """Run the AngleEnergy operator and return the ReduceSum of its output.

    The operator output is stored on ``self.angle_energy`` and the reduced
    value on ``self.sigma_of_angle_ene``.
    """
    energy_op = P.AngleEnergy(self.angle_numbers)
    energies = energy_op(uint_crd, uint_dr_to_dr_cof, self.atom_a, self.atom_b, self.atom_c,
                         self.angle_k, self.angle_theta0)
    self.angle_energy = energies
    total = P.ReduceSum()(energies)
    self.sigma_of_angle_ene = total
    return total
def Angle_Force_With_Atom_Energy(self, uint_crd, scaler):
    """Run the AngleForceWithAtomEnergy operator and return its two outputs.

    Prints the angle count on every call and keeps the operator instance on
    ``self.afae``.
    """
    print("angele angle numbers:", self.angle_numbers)
    self.afae = P.AngleForceWithAtomEnergy(angle_numbers=self.angle_numbers)
    operands = (uint_crd, scaler, self.atom_a, self.atom_b, self.atom_c, self.angle_k, self.angle_theta0)
    force, atom_energy = self.afae(*operands)
    return force, atom_energy

View File

@ -12,19 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""bond class"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class Bond(nn.Cell):
"""bond class"""
'''Bond'''
class Bond:
'''Bond'''
def __init__(self, controller, md_info):
super(Bond, self).__init__()
self.atom_numbers = md_info.atom_numbers
@ -32,13 +23,11 @@ class Bond(nn.Cell):
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
self.bond_k = Tensor(np.asarray(self.h_k, np.float32), mstype.float32)
self.bond_r0 = Tensor(np.asarray(self.h_r0, np.float32), mstype.float32)
def process1(self, context):
"""process1: read information from amberfile"""
def read_information_from_amberfile(self, file_path):
'''read amber file'''
file = open(file_path, 'r')
context = file.readlines()
file.close()
for idx, val in enumerate(context):
if idx < len(context) - 1:
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
@ -48,7 +37,6 @@ class Bond(nn.Cell):
self.bond_with_hydrogen = value[2]
self.bond_numbers = value[3]
self.bond_numbers += self.bond_with_hydrogen
print(self.bond_numbers)
information = []
information.extend(value)
while count < 16:
@ -76,13 +64,6 @@ class Bond(nn.Cell):
self.bond_type_k = information[:self.bond_type_numbers]
break
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.process1(context)
for idx, val in enumerate(context):
if "%FLAG BOND_EQUIL_VALUE" in val:
count = 0
@ -98,7 +79,10 @@ class Bond(nn.Cell):
count += len(value)
self.bond_type_r = information[:self.bond_type_numbers]
break
self.processor(context)
def processor(self, context):
'''processor'''
for idx, val in enumerate(context):
if "%FLAG BONDS_INC_HYDROGEN" in val:
self.h_atom_a = [0] * self.bond_numbers
@ -128,6 +112,7 @@ class Bond(nn.Cell):
for idx, val in enumerate(context):
if "%FLAG BONDS_WITHOUT_HYDROGEN" in val:
count = 0
start_idx = idx
information = []
@ -147,17 +132,3 @@ class Bond(nn.Cell):
self.h_k[i] = self.bond_type_k[tmpi]
self.h_r0[i] = self.bond_type_r[tmpi]
break
def Bond_Energy(self, uint_crd, uint_dr_to_dr_cof):
    """Run the BondEnergy operator and return the ReduceSum of its output.

    The operator output is stored on ``self.bond_energy`` and the reduced
    value on ``self.sigma_of_bond_ene``.
    """
    energy_op = P.BondEnergy(self.bond_numbers, self.atom_numbers)
    energies = energy_op(uint_crd, uint_dr_to_dr_cof, self.atom_a, self.atom_b, self.bond_k, self.bond_r0)
    self.bond_energy = energies
    total = P.ReduceSum()(energies)
    self.sigma_of_bond_ene = total
    return total
def Bond_Force_With_Atom_Energy(self, uint_crd, scaler):
    """Run the BondForceWithAtomEnergy operator and return its two outputs.

    The operator instance is kept on ``self.bfatomenergy``.
    """
    self.bfatomenergy = P.BondForceWithAtomEnergy(bond_numbers=self.bond_numbers,
                                                  atom_numbers=self.atom_numbers)
    operands = (uint_crd, scaler, self.atom_a, self.atom_b, self.bond_k, self.bond_r0)
    force, per_atom_energy = self.bfatomenergy(*operands)
    return force, per_atom_energy

View File

@ -12,38 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dihedral class"""
'''Dihedral'''
import math
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class Dihedral(nn.Cell):
"""dihedral class"""
class Dihedral:
'''Dihedral'''
def __init__(self, controller):
super(Dihedral, self).__init__()
self.CONSTANT_Pi = 3.1415926535897932
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
self.atom_c = Tensor(np.asarray(self.h_atom_c, np.int32), mstype.int32)
self.atom_d = Tensor(np.asarray(self.h_atom_d, np.int32), mstype.int32)
self.pk = Tensor(np.asarray(self.pk, np.float32), mstype.float32)
self.gamc = Tensor(np.asarray(self.gamc, np.float32), mstype.float32)
self.gams = Tensor(np.asarray(self.gams, np.float32), mstype.float32)
self.pn = Tensor(np.asarray(self.pn, np.float32), mstype.float32)
self.ipn = Tensor(np.asarray(self.ipn, np.int32), mstype.int32)
def process1(self, context):
"""process1: read information from amberfile"""
def read_information_from_amberfile(self, file_path):
'''read amber file'''
file = open(file_path, 'r')
context = file.readlines()
file.close()
for idx, val in enumerate(context):
if idx < len(context) - 1:
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
@ -115,15 +100,10 @@ class Dihedral(nn.Cell):
count += len(value)
self.pn_type = information[:self.dihedral_type_numbers]
break
self.processor(context)
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.process1(context)
def processor(self, context):
'''processor'''
self.h_atom_a = [0] * self.dihedral_numbers
self.h_atom_b = [0] * self.dihedral_numbers
self.h_atom_c = [0] * self.dihedral_numbers
@ -204,18 +184,3 @@ class Dihedral(nn.Cell):
for i in range(self.dihedral_numbers):
if self.h_atom_c[i] < 0:
self.h_atom_c[i] *= -1
def Dihedral_Engergy(self, uint_crd, uint_dr_to_dr_cof):
    """Run the DihedralEnergy operator and return the ReduceSum of its output.

    The operator output is stored on ``self.dihedral_energy`` and the reduced
    value on ``self.sigma_of_dihedral_ene``.
    """
    energy_op = P.DihedralEnergy(self.dihedral_numbers)
    energies = energy_op(uint_crd, uint_dr_to_dr_cof, self.atom_a, self.atom_b, self.atom_c, self.atom_d,
                         self.ipn, self.pk, self.gamc, self.gams, self.pn)
    self.dihedral_energy = energies
    total = P.ReduceSum()(energies)
    self.sigma_of_dihedral_ene = total
    return total
def Dihedral_Force_With_Atom_Energy(self, uint_crd, scaler):
    """Run the DihedralForceWithAtomEnergy operator and return its two outputs.

    The operator instance is kept on ``self.dfae``; the outputs are also
    stored on ``self.frc`` and ``self.ene``.
    """
    self.dfae = P.DihedralForceWithAtomEnergy(dihedral_numbers=self.dihedral_numbers)
    operands = (uint_crd, scaler, self.atom_a, self.atom_b, self.atom_c, self.atom_d,
                self.ipn, self.pk, self.gamc, self.gams, self.pn)
    force, energy = self.dfae(*operands)
    self.frc = force
    self.ene = energy
    return force, energy

View File

@ -12,18 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Langevin Liujian MD class"""
'''LagevinLiuJian'''
import math
import numpy as np
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
class Langevin_Liujian:
"""Langevin_Liujian class"""
'''LagevinLiuJian'''
def __init__(self, controller, atom_numbers):
self.atom_numbers = atom_numbers
if controller.amber_parm is not None:
@ -37,29 +33,27 @@ class Langevin_Liujian:
controller.Command_Set["target_temperature"])
self.gamma_ln = 1.0 if "langevin_gamma" not in controller.Command_Set else float(
controller.Command_Set["langevin_gamma"])
self.rand_seed = 0 if "langevin_seed" not in controller.Command_Set else float(
controller.Command_Set["langevin_seed"]) # jiahong0315
self.rand_seed = 1 if "langevin_seed" not in controller.Command_Set else float(
controller.Command_Set["langevin_seed"])
self.max_velocity = 10000.0 if "velocity_max" not in controller.Command_Set else float(
controller.Command_Set["velocity_max"])
assert self.max_velocity > 0
self.is_max_velocity = 0 if "velocity_max" not in controller.Command_Set else 1
print("target temperature is ", self.target_temperature)
print("friction coefficient is ", self.gamma_ln, "ps^-1")
print("random seed is ", self.rand_seed)
self.dt = float(controller.Command_Set["dt"])
self.dt *= self.CONSTANT_TIME_CONVERTION
self.half_dt = 0.5 * self.dt
self.float4_numbers = math.ceil(3.0 * self.atom_numbers / 4.0)
self.rand_state = np.float32(np.zeros([math.ceil(3 * self.atom_numbers / 4.0) * 16,]))
self.gamma_ln = self.gamma_ln / self.CONSTANT_TIME_CONVERTION
self.exp_gamma = math.exp(-1 * self.gamma_ln * self.dt)
self.sqrt_gamma = math.sqrt((1. - self.exp_gamma * self.exp_gamma) * self.target_temperature * self.CONSTANT_kB)
self.h_sqrt_mass = [0] * self.atom_numbers
for i in range(self.atom_numbers):
self.h_sqrt_mass[i] = self.sqrt_gamma * math.sqrt(1. / self.h_mass[i])
self.d_sqrt_mass = Tensor(self.h_sqrt_mass, mstype.float32)
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
'''read amber file'''
file = open(file_path, 'r')
context = file.readlines()
file.close()
@ -81,29 +75,3 @@ class Langevin_Liujian:
for i in range(self.atom_numbers):
self.h_mass[i] = information[i]
break
def MDIterationLeapFrog_Liujian(self, atom_numbers, half_dt, dt, exp_gamma, inverse_mass, sqrt_mass_inverse, vel,
                                crd, frc, random_frc):
    """Advance velocity and coordinates by one step.

    Both mass arrays are reshaped to a single column before use. The
    velocity is updated from the force, the coordinates take two half-step
    moves, and between them the velocity is damped by ``exp_gamma`` with the
    scaled ``random_frc`` added.

    Returns:
        (vel, crd, frc, acc), where ``frc`` is a fresh all-zero float32
        Tensor of shape (atom_numbers, 3).
    """
    mass_inv_col = inverse_mass.reshape((-1, 1))
    noise_scale_col = sqrt_mass_inverse.reshape((-1, 1))
    acc = mass_inv_col * frc
    kicked_vel = vel + dt * acc
    mid_crd = crd + half_dt * kicked_vel
    new_vel = exp_gamma * kicked_vel + noise_scale_col * random_frc
    new_crd = mid_crd + half_dt * new_vel
    # The incoming force is discarded; callers get a zeroed buffer back.
    cleared_frc = Tensor(np.zeros((atom_numbers, 3)), mstype.float32)
    return new_vel, new_crd, cleared_frc, acc
def MD_Iteration_Leap_Frog(self, d_mass_inverse, vel_in, crd_in, frc_in):
    """Run one MDIterationLeapFrog_Liujian step with this object's settings.

    Reseeds numpy's global RNG from ``self.rand_seed`` and stores an all-zero
    random force on ``self.rand_force`` before delegating.

    Returns:
        (vel, crd, frc, acc) as produced by MDIterationLeapFrog_Liujian.
    """
    np.random.seed(int(self.rand_seed))
    # NOTE(review): a Gaussian draw via np.random.randn was commented out here;
    # the random force is currently always zero.
    self.rand_force = Tensor(np.zeros((self.atom_numbers, 3)), mstype.float32)
    step_args = dict(atom_numbers=self.atom_numbers, half_dt=self.half_dt, dt=self.dt,
                     exp_gamma=self.exp_gamma, inverse_mass=d_mass_inverse,
                     sqrt_mass_inverse=self.d_sqrt_mass, vel=vel_in, crd=crd_in,
                     frc=frc_in, random_frc=self.rand_force)
    new_vel, new_crd, new_frc, new_acc = self.MDIterationLeapFrog_Liujian(**step_args)
    return new_vel, new_crd, new_frc, new_acc

View File

@ -12,30 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lennard jones"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class Lennard_Jones_Information(nn.Cell):
"""class Lennard Jones Information"""
'''Lennard Jones'''
class Lennard_Jones_Information:
'''Lennard Jones'''
def __init__(self, controller):
super(Lennard_Jones_Information, self).__init__()
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.atom_LJ_type = Tensor(np.asarray(self.atom_LJ_type, dtype=np.int32), mstype.int32)
self.LJ_A = Tensor(np.asarray(self.LJ_A, dtype=np.float32), mstype.float32)
self.LJ_B = Tensor(np.asarray(self.LJ_B, dtype=np.float32), mstype.float32)
self.LJ_energy_sum = 0
self.LJ_energy = 0
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
'''read amber file'''
file = open(file_path, 'r')
context = file.readlines()
file.close()
@ -48,8 +34,8 @@ class Lennard_Jones_Information(nn.Cell):
value = list(map(int, context[start_idx].strip().split()))
self.atom_numbers = value[0]
self.atom_type_numbers = value[1]
self.pair_type_numbers = int(self.atom_type_numbers * (self.atom_type_numbers + 1) / 2)
print(self.pair_type_numbers)
self.pair_type_numbers = int(
self.atom_type_numbers * (self.atom_type_numbers + 1) / 2) # TODO 这个地方有问题啊
break
self.atom_LJ_type = [0] * self.atom_numbers
for idx, val in enumerate(context):
@ -102,21 +88,3 @@ class Lennard_Jones_Information(nn.Cell):
for i in range(self.pair_type_numbers):
self.LJ_B[i] = 6.0 * information[i]
break
def LJ_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof, nl_atom_numbers, nl_atom_serial, cutoff_square):
    """Run the LJEnergy operator and return the ReduceSum of its output.

    ``uint_crd_with_LJ`` is unpacked into exactly three items
    (coordinates, LJ type, charge). The operator output is stored on
    ``self.LJ_energy`` and the reduced value on ``self.LJ_energy_sum``.
    """
    coords, lj_type, atom_charge = uint_crd_with_LJ
    energy_op = P.LJEnergy(self.atom_numbers, cutoff_square)
    energies = energy_op(coords, lj_type, atom_charge, uint_dr_to_dr_cof, nl_atom_numbers, nl_atom_serial,
                         self.LJ_A, self.LJ_B)
    self.LJ_energy = energies
    total = P.ReduceSum()(energies)
    self.LJ_energy_sum = total
    return total
def LJ_Force_With_PME_Direct_Force(self, atom_numbers, uint_crd_with_LJ, uint_dr_to_dr_cof, nl_number, nl_serial,
                                   cutoff, beta):
    """Run the LJForceWithPMEDirectForce operator and return its output.

    Asserts that ``atom_numbers`` matches ``self.atom_numbers`` and that
    ``uint_crd_with_LJ`` is a tuple, which is then unpacked into exactly
    three items. The operator instance is kept on ``self.ljfd``.
    """
    assert atom_numbers == self.atom_numbers
    assert isinstance(uint_crd_with_LJ, tuple)
    coords, lj_type, atom_charge = uint_crd_with_LJ
    self.ljfd = P.LJForceWithPMEDirectForce(atom_numbers, cutoff, beta)
    operands = (coords, lj_type, atom_charge, uint_dr_to_dr_cof, nl_number, nl_serial, self.LJ_A, self.LJ_B)
    return self.ljfd(*operands)

View File

@ -12,36 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""md information"""
'''MD Information'''
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class md_information(nn.Cell):
"""class md information"""
class md_information:
'''MD Information'''
def __init__(self, controller):
super(md_information, self).__init__()
CONSTANT_TIME_CONVERTION = 20.455
CONSTANT_UINT_MAX_FLOAT = 4294967296.0
self.md_task = controller.md_task
self.mode = 0 if "mode" not in controller.Command_Set else int(controller.Command_Set["mode"])
self.dt = 0.001 * CONSTANT_TIME_CONVERTION if "dt" not in controller.Command_Set \
else float(controller.Command_Set["dt"]) * CONSTANT_TIME_CONVERTION
self.skin = 2.0 if "skin" not in controller.Command_Set \
else float(controller.Command_Set["skin"])
self.dt = 0.001 * CONSTANT_TIME_CONVERTION if "dt" not in controller.Command_Set else float(
controller.Command_Set["dt"]) * CONSTANT_TIME_CONVERTION
self.skin = 2.0 if "skin" not in controller.Command_Set else float(controller.Command_Set["skin"])
self.trans_vec = [self.skin, self.skin, self.skin]
self.trans_vec_minus = -1 * self.trans_vec
self.step_limit = 1000 if "step_limit" not in controller.Command_Set else int(
controller.Command_Set["step_limit"])
self.netfrc = 0 if "net_force" not in controller.Command_Set else int(controller.Command_Set["net_force"])
self.ntwx = 1000 if "write_information_interval" not in controller.Command_Set else \
int(controller.Command_Set["write_information_interval"])
self.ntce = self.step_limit + 1 if "calculate_energy_interval" not in controller.Command_Set else \
int(controller.Command_Set["calculate_energy_interval"])
self.ntwx = 1000 if "write_information_interval" not in controller.Command_Set else int(
controller.Command_Set["write_information_interval"])
self.ntce = self.step_limit + 1 if "calculate_energy_interval" not in controller.Command_Set else int(
controller.Command_Set["calculate_energy_interval"])
self.atom_numbers = 0
self.residue_numbers = 0
self.density = 0.0
@ -51,7 +44,6 @@ class md_information(nn.Cell):
self.h_mass = []
self.h_mass_inverse = []
self.h_charge = []
self.steps = 0
if controller.amber_parm is not None:
self.read_basic_system_information_from_amber_file(controller.amber_parm)
@ -67,23 +59,13 @@ class md_information(nn.Cell):
self.uint_dr_to_dr_cof = [1.0 / self.crd_to_uint_crd_cof[0], 1.0 / self.crd_to_uint_crd_cof[1],
1.0 / self.crd_to_uint_crd_cof[2]]
self.density *= 1e24 / 6.023e23 / (self.box_length[0] * self.box_length[1] * self.box_length[2])
self.frc = Tensor(np.zeros((self.atom_numbers, 3)), mstype.float32)
self.crd = Tensor(np.array(self.coordinate, dtype=np.float32).reshape((self.atom_numbers, 3)), mstype.float32)
self.crd_n = np.array(self.coordinate).reshape([self.atom_numbers, 3])
self.crd_old = Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32)
self.uint_crd = Tensor(np.zeros([self.atom_numbers, 3], dtype=np.uint32), mstype.uint32)
self.charge = Tensor(self.h_charge, mstype.float32)
self.crd_to_uint_crd_cof_n = np.array(self.crd_to_uint_crd_cof)
self.crd_to_uint_crd_cof = Tensor(self.crd_to_uint_crd_cof, mstype.float32)
self.uint_dr_to_dr_cof = Tensor(self.uint_dr_to_dr_cof, mstype.float32)
self.uint_crd_with_LJ = None
self.d_mass_inverse = Tensor(self.h_mass_inverse, mstype.float32)
self.d_res_start = Tensor(self.h_res_start, mstype.int32)
self.d_res_end = Tensor(self.h_res_end, mstype.int32)
self.d_mass = Tensor(self.h_mass, mstype.float32)
def process1(self, context):
"""process1: read basic system information from amber file"""
self.velocity = np.reshape(np.asarray(self.velocity, np.float32), [self.atom_numbers, 3])
def read_basic_system_information_from_amber_file(self, path):
'''read amber file'''
file = open(path, 'r')
context = file.readlines()
for idx, val in enumerate(context):
if idx < len(context) - 1:
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
@ -95,16 +77,10 @@ class md_information(nn.Cell):
start_idx += 1
value = list(map(int, context[start_idx].strip().split()))
count += len(value)
self.residue_numbers = list(map(int, context[start_idx].strip().split()))[10 - (count - 10)]
self.residue_numbers = list(map(int, context[start_idx].strip().split()))[
10 - (count - 10)] # may exist bug
break
def read_basic_system_information_from_amber_file(self, path):
"""read basic system information from amber file"""
file = open(path, 'r')
context = file.readlines()
file.close()
self.process1(context)
if self.residue_numbers != 0 and self.atom_numbers != 0:
for idx, val in enumerate(context):
if "%FLAG RESIDUE_POINTER" in val:
@ -124,42 +100,45 @@ class md_information(nn.Cell):
self.h_res_start.append(self.lin_serial[-1] - 1)
self.h_res_end.append(self.atom_numbers + 1 - 1)
break
self.processor(context)
for idx, val in enumerate(context):
if "%FLAG MASS" in val:
count = 0
start_idx = idx
while count != self.atom_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
self.h_mass.extend(value)
count += len(value)
for i in range(self.atom_numbers):
if self.h_mass[i] == 0:
self.h_mass_inverse.append(0.0)
else:
self.h_mass_inverse.append(1.0 / self.h_mass[i])
self.density += self.h_mass[i]
break
for idx, val in enumerate(context):
if "%FLAG CHARGE" in val:
count = 0
start_idx = idx
while count != self.atom_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
self.h_charge.extend(value)
count += len(value)
break
def processor(self, context):
    '''Fill the mass, inverse-mass, density and charge data from the file lines.

    For the first line containing "%FLAG MASS", and again for the first line
    containing "%FLAG CHARGE", the following lines are parsed as
    whitespace-separated floats (lines containing "%FORMAT" are skipped)
    until exactly ``self.atom_numbers`` values have been read. Masses extend
    ``self.h_mass``, charges extend ``self.h_charge``. Each mass also appends
    its reciprocal (0.0 for a zero mass) to ``self.h_mass_inverse`` and is
    added to ``self.density``.
    '''

    def load_section(flag, sink):
        # Returns True when a line carrying `flag` was found and consumed.
        for pos, text in enumerate(context):
            if flag not in text:
                continue
            loaded = 0
            cursor = pos
            while loaded != self.atom_numbers:
                cursor += 1
                if "%FORMAT" not in context[cursor]:
                    numbers = [float(token) for token in context[cursor].strip().split()]
                    sink.extend(numbers)
                    loaded += len(numbers)
            return True
        return False

    if load_section("%FLAG MASS", self.h_mass):
        for atom in range(self.atom_numbers):
            mass = self.h_mass[atom]
            self.h_mass_inverse.append(0.0 if mass == 0 else 1.0 / mass)
            self.density += mass
    load_section("%FLAG CHARGE", self.h_charge)
def read_basic_system_information_from_rst7(self, path, irest):
"""read basic system information from rst7"""
'''read rst7 file'''
file = open(path, 'r')
context = file.readlines()
file.close()
@ -191,22 +170,4 @@ class md_information(nn.Cell):
self.coordinate = information[: 3 * self.atom_numbers]
self.velocity = [0.0] * (3 * self.atom_numbers)
self.box_length = information[3 * self.atom_numbers:3 * self.atom_numbers + 3]
self.vel = Tensor(self.velocity, mstype.float32).reshape((self.atom_numbers, 3))
self.acc = Tensor(np.zeros((self.atom_numbers, 3), dtype=np.float32), mstype.float32)
def MD_Information_Crd_To_Uint_Crd(self):
    """Scale ``self.crd`` by ``self.crd_to_uint_crd_cof`` and store it as uint32.

    The product is computed on numpy copies as crd * (0.5 * cof) * 2, wrapped
    in a uint32 Tensor, stored on ``self.uint_crd`` and returned.
    """
    crd_np = self.crd.asnumpy()
    half_scale = 0.5 * self.crd_to_uint_crd_cof.asnumpy()
    scaled = crd_np * half_scale * 2
    self.uint_crd = Tensor(scaled, mstype.uint32)
    return self.uint_crd
def Centerize(self):
    """Placeholder that does nothing and returns None."""
    return None
def MD_Information_Temperature(self):
    """Run the MDTemperature operator and return the ReduceSum of its output.

    The operator instance is kept on ``self.mdtemp``, its output on
    ``self.res_ek_energy`` and the reduced value on ``self.d_temperature``.
    """
    self.mdtemp = P.MDTemperature(self.residue_numbers, self.atom_numbers)
    operands = (self.d_res_start, self.d_res_end, self.vel, self.d_mass)
    self.res_ek_energy = self.mdtemp(*operands)
    total = P.ReduceSum()(self.res_ek_energy)
    self.d_temperature = total
    return total
print("system size is ", self.box_length[0], self.box_length[1], self.box_length[2])

View File

@ -12,20 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""nb14"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class NON_BOND_14(nn.Cell):
"""class Non bond 14"""
'''NON BOND'''
class NON_BOND_14:
'''NON BOND'''
def __init__(self, controller, dihedral, atom_numbers):
super(NON_BOND_14, self).__init__()
self.dihedral_with_hydrogen = dihedral.dihedral_with_hydrogen
self.dihedral_numbers = dihedral.dihedral_numbers
self.dihedral_type_numbers = dihedral.dihedral_type_numbers
@ -34,14 +24,20 @@ class NON_BOND_14(nn.Cell):
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.h_atom_a = self.h_atom_a[:self.nb14_numbers]
self.h_atom_b = self.h_atom_b[:self.nb14_numbers]
self.h_lj_scale_factor = self.h_lj_scale_factor[:self.nb14_numbers]
self.h_cf_scale_factor = self.h_cf_scale_factor[:self.nb14_numbers]
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
self.lj_scale_factor = Tensor(np.asarray(self.h_lj_scale_factor, np.float32), mstype.float32)
self.cf_scale_factor = Tensor(np.asarray(self.h_cf_scale_factor, np.float32), mstype.float32)
def read_information_from_amberfile(self, file_path):
'''read amber file'''
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.cf_scale_type = [0] * self.dihedral_type_numbers
self.lj_scale_type = [0] * self.dihedral_type_numbers
def process1(self, context):
"""process1: read information from amberfile"""
for idx, val in enumerate(context):
if "%FLAG SCEE_SCALE_FACTOR" in val:
count = 0
@ -73,16 +69,10 @@ class NON_BOND_14(nn.Cell):
count += len(value)
self.lj_scale_type = information[:self.dihedral_type_numbers]
break
self.processor(context)
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.cf_scale_type = [0] * self.dihedral_type_numbers
self.lj_scale_type = [0] * self.dihedral_type_numbers
self.process1(context)
def processor(self, context):
'''processor'''
self.h_atom_a = [0] * self.dihedral_numbers
self.h_atom_b = [0] * self.dihedral_numbers
self.h_lj_scale_factor = [0] * self.dihedral_numbers
@ -154,42 +144,3 @@ class NON_BOND_14(nn.Cell):
break
self.nb14_numbers = nb14_numbers
def Non_Bond_14_LJ_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof, LJ_A, LJ_B):
    """Run the Dihedral14LJEnergy operator and return the ReduceSum of its output.

    ``uint_crd_with_LJ`` must be a tuple and is unpacked into exactly three
    items. The operator output is stored on ``self.LJ_energy`` and the
    reduced value on ``self.nb14_lj_energy_sum``.
    """
    assert isinstance(uint_crd_with_LJ, tuple)
    coords, lj_type, atom_charge = uint_crd_with_LJ
    energy_op = P.Dihedral14LJEnergy(self.nb14_numbers, self.atom_numbers)
    energies = energy_op(coords, lj_type, atom_charge, uint_dr_to_dr_cof, self.atom_a, self.atom_b,
                         self.lj_scale_factor, LJ_A, LJ_B)
    self.LJ_energy = energies
    total = P.ReduceSum()(energies)
    self.nb14_lj_energy_sum = total
    return total
def Non_Bond_14_CF_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof):
    """Run the Dihedral14CFEnergy operator and return the ReduceSum of its output.

    ``uint_crd_with_LJ`` must be a tuple and is unpacked into exactly three
    items. The operator output is stored on ``self.CF_energy`` and the
    reduced value on ``self.nb14_cf_energy_sum``.
    """
    assert isinstance(uint_crd_with_LJ, tuple)
    coords, lj_type, atom_charge = uint_crd_with_LJ
    energy_op = P.Dihedral14CFEnergy(self.nb14_numbers, self.atom_numbers)
    energies = energy_op(coords, lj_type, atom_charge, uint_dr_to_dr_cof, self.atom_a, self.atom_b,
                         self.cf_scale_factor)
    self.CF_energy = energies
    total = P.ReduceSum()(energies)
    self.nb14_cf_energy_sum = total
    return total
def Non_Bond_14_LJ_CF_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof, LJ_A, LJ_B):
    """Return the results of Non_Bond_14_LJ_Energy and Non_Bond_14_CF_Energy.

    The LJ call runs first. Both results are also stored on
    ``self.nb14_lj_energy_sum`` and ``self.nb14_cf_energy_sum``.
    """
    assert isinstance(uint_crd_with_LJ, tuple)
    lj_total = self.Non_Bond_14_LJ_Energy(uint_crd_with_LJ, uint_dr_to_dr_cof, LJ_A, LJ_B)
    self.nb14_lj_energy_sum = lj_total
    cf_total = self.Non_Bond_14_CF_Energy(uint_crd_with_LJ, uint_dr_to_dr_cof)
    self.nb14_cf_energy_sum = cf_total
    return lj_total, cf_total
def Non_Bond_14_LJ_CF_Force_With_Atom_Energy(self, uint_crd_with_LJ, boxlength, LJ_A, LJ_B):
    """Run the Dihedral14LJCFForceWithAtomEnergy operator and return its two outputs.

    The operator instance is created and kept on ``self.d14lj`` before the
    tuple check on ``uint_crd_with_LJ``. The outputs are also stored on
    ``self.frc`` and ``self.atom_ene``.
    """
    self.d14lj = P.Dihedral14LJCFForceWithAtomEnergy(nb14_numbers=self.nb14_numbers, atom_numbers=self.atom_numbers)
    assert isinstance(uint_crd_with_LJ, tuple)
    coords, lj_type, atom_charge = uint_crd_with_LJ
    operands = (coords, lj_type, atom_charge, boxlength, self.atom_a, self.atom_b,
                self.lj_scale_factor, self.cf_scale_factor, LJ_A, LJ_B)
    force, per_atom_energy = self.d14lj(*operands)
    self.frc = force
    self.atom_ene = per_atom_energy
    return force, per_atom_energy

View File

@ -12,25 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""neighbour list"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class nb_infomation(nn.Cell):
"""neighbour list"""
'''Neighbor List'''
class neighbor_list:
'''Neighbor List'''
def __init__(self, controller, atom_numbers, box_length):
super(nb_infomation, self).__init__()
self.refresh_interval = 20 if "neighbor_list_refresh_interval" not in controller.Command_Set else \
int(controller.Command_Set["neighbor_list_refresh_interval"])
self.max_atom_in_grid_numbers = 64 if "max_atom_in_grid_numbers" not in controller.Command_Set else \
int(controller.Command_Set["max_atom_in_grid_numbers"])
self.max_neighbor_numbers = 800 if "max_neighbor_numbers" not in controller.Command_Set else \
int(controller.Command_Set["max_neighbor_numbers"])
self.refresh_interval = 20 if "neighbor_list_refresh_interval" not in controller.Command_Set else int(
controller.Command_Set["neighbor_list_refresh_interval"])
self.max_atom_in_grid_numbers = 64 if "max_atom_in_grid_numbers" not in controller.Command_Set else int(
controller.Command_Set["max_atom_in_grid_numbers"])
self.max_neighbor_numbers = 800 if "max_neighbor_numbers" not in controller.Command_Set else int(
controller.Command_Set["max_neighbor_numbers"])
self.skin = 2.0 if "skin" not in controller.Command_Set else float(controller.Command_Set["skin"])
self.cutoff = 10.0 if "cut" not in controller.Command_Set else float(controller.Command_Set["cut"])
self.cutoff_square = self.cutoff * self.cutoff
@ -47,24 +38,10 @@ class nb_infomation(nn.Cell):
self.Initial_Neighbor_Grid()
self.not_first_time = 0
self.refresh_count = 0
self.atom_numbers_in_grid_bucket = Tensor(np.asarray(self.atom_numbers_in_grid_bucket, np.int32), mstype.int32)
self.bucket = Tensor(
np.asarray(self.bucket, np.int32).reshape([self.grid_numbers, self.max_atom_in_grid_numbers]), mstype.int32)
self.grid_N = Tensor(np.asarray(self.grid_N, np.int32), mstype.int32)
self.grid_length_inverse = Tensor(np.asarray(self.grid_length_inverse, np.float32), mstype.float32)
self.atom_in_grid_serial = Tensor(np.zeros(self.atom_numbers, np.int32), mstype.int32)
self.pointer = Tensor(np.asarray(self.pointer, np.int32).reshape([self.grid_numbers, 125]), mstype.int32)
self.nl_atom_numbers = Tensor(np.zeros(self.atom_numbers, np.int32), mstype.int32)
self.nl_atom_serial = Tensor(np.zeros([self.atom_numbers, self.max_neighbor_numbers], np.int32), mstype.int32)
self.excluded_list_start = Tensor(np.asarray(self.excluded_list_start, np.int32), mstype.int32)
self.excluded_list = Tensor(np.asarray(self.excluded_list, np.int32), mstype.int32)
self.excluded_numbers = Tensor(np.asarray(self.excluded_numbers, np.int32), mstype.int32)
self.need_refresh_flag = Tensor(np.asarray([0], np.int32), mstype.int32)
self.refresh_count = [0]
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
'''read amber file'''
file = open(file_path, 'r')
context = file.readlines()
file.close()
@ -85,6 +62,7 @@ class nb_infomation(nn.Cell):
information.extend(value)
count += len(value)
self.excluded_atom_numbers = information[10]
print("excluded atom numbers ", self.excluded_atom_numbers)
break
for idx, val in enumerate(context):
if "%FLAG NUMBER_EXCLUDED_ATOMS" in val:
@ -125,37 +103,22 @@ class nb_infomation(nn.Cell):
count = 0
for i in range(self.atom_numbers):
tmp_list = []
for _ in range(self.excluded_numbers[i]):
if self.excluded_numbers[i] == 1:
tmp_list.append(information[count] - 1)
if information[count] == 0:
self.excluded_numbers[i] = 0
count += 1
tmp_list = sorted(tmp_list)
else:
for _ in range(self.excluded_numbers[i]):
tmp_list.append(information[count] - 1)
count += 1
tmp_list = sorted(tmp_list)
self.excluded_list.extend(tmp_list)
break
def fun(self, Nx, Ny, Nz, l, m, temp_grid_serial, count):
    """Write the serials of the five z-neighbours of one (x, y) cell column.

    For the column at offset (l, m) from grid cell (Nx, Ny, Nz), the flattened
    serial ``zz * self.Nxy + yy * self.Nx + xx`` is stored for every z offset
    in [-2, 2], starting at ``temp_grid_serial[count]``.

    Args:
        Nx, Ny, Nz: indices of the central grid cell.
        l, m: x and y offsets of the column relative to the central cell.
        temp_grid_serial: pre-sized mutable sequence that receives the serials.
        count: index of the first slot to fill.

    Returns:
        ``(temp_grid_serial, count)`` where ``count`` has advanced by five.
    """
    def wrap_once(index, size):
        # Exactly one periodic shift: an index further than `size` outside
        # the grid is shifted once and may still be out of range.
        if index < 0:
            return index + size
        if index >= size:
            return index - size
        return index

    # x and y do not depend on the z offset, so they are resolved once.
    col_x = wrap_once(Nx + l, self.Nx)
    col_y = wrap_once(Ny + m, self.Ny)
    for dz in range(-2, 3):
        col_z = wrap_once(Nz + dz, self.Nz)
        temp_grid_serial[count] = col_z * self.Nxy + col_y * self.Nx + col_x
        count += 1
    return temp_grid_serial, count
def Initial_Neighbor_Grid(self):
"""initial neighbour grid"""
'''init neighbor grid'''
half_cutoff = self.half_cutoff_with_skin
self.Nx = int(self.box_length[0] / half_cutoff)
self.Ny = int(self.box_length[1] / half_cutoff)
@ -177,31 +140,23 @@ class nb_infomation(nn.Cell):
count = 0
for l in range(-2, 3):
for m in range(-2, 3):
temp_grid_serial, count = self.fun(Nx, Ny, Nz, l, m, temp_grid_serial, count)
for n in range(-2, 3):
xx = Nx + l
if xx < 0:
xx = xx + self.Nx
elif xx >= self.Nx:
xx = xx - self.Nx
yy = Ny + m
if yy < 0:
yy = yy + self.Ny
elif yy >= self.Ny:
yy = yy - self.Ny
zz = Nz + n
if zz < 0:
zz = zz + self.Nz
elif zz >= self.Nz:
zz = zz - self.Nz
temp_grid_serial[count] = zz * self.Nxy + yy * self.Nx + xx
count += 1
temp_grid_serial = sorted(temp_grid_serial)
self.pointer.extend(temp_grid_serial)
def NeighborListUpdate(self, crd, old_crd, uint_crd, crd_to_uint_crd_cof, uint_dr_to_dr_cof, box_length,
                       not_first_time=0):
    """Create a ``P.NeighborListUpdate`` operator from this object's settings and run it.

    The operator is rebuilt on every call and stored on ``self.neighbor_list_update``;
    ``not_first_time`` is stored on ``self.not_first_time`` and forwarded to the operator.

    Args:
        crd, old_crd, uint_crd: coordinate tensors passed straight to the operator.
        crd_to_uint_crd_cof, uint_dr_to_dr_cof: scale-factor tensors passed to the operator.
        box_length: box size tensor passed to the operator.
        not_first_time: flag forwarded to the operator constructor (default 0).

    Returns:
        Whatever the operator call returns.
    """
    self.not_first_time = not_first_time
    op_settings = dict(
        grid_numbers=self.grid_numbers,
        atom_numbers=self.atom_numbers,
        refresh_count=self.refresh_count,
        not_first_time=self.not_first_time,
        Nxy=self.Nxy,
        excluded_atom_numbers=self.excluded_atom_numbers,
        cutoff_square=self.cutoff_square,
        half_skin_square=self.half_skin_square,
        cutoff_with_skin=self.cutoff_with_skin,
        half_cutoff_with_skin=self.half_cutoff_with_skin,
        cutoff_with_skin_square=self.cutoff_with_skin_square,
        refresh_interval=self.refresh_interval,
        cutoff=self.cutoff,
        skin=self.skin,
        max_atom_in_grid_numbers=self.max_atom_in_grid_numbers,
        max_neighbor_numbers=self.max_neighbor_numbers,
    )
    self.neighbor_list_update = P.NeighborListUpdate(**op_settings)
    # Positional order is the operator's input order; do not reorder.
    op_inputs = (
        self.atom_numbers_in_grid_bucket, self.bucket, crd, box_length, self.grid_N,
        self.grid_length_inverse, self.atom_in_grid_serial, old_crd,
        crd_to_uint_crd_cof, uint_crd, self.pointer, self.nl_atom_numbers,
        self.nl_atom_serial, uint_dr_to_dr_cof, self.excluded_list_start,
        self.excluded_list, self.excluded_numbers, self.need_refresh_flag,
    )
    return self.neighbor_list_update(*op_inputs)

View File

@ -12,20 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PME"""
'''PME'''
import math
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class Particle_Mesh_Ewald(nn.Cell):
"""class Particle_Mesh_Ewald"""
class Particle_Mesh_Ewald():
'''PME'''
def __init__(self, controller, md_info):
super(Particle_Mesh_Ewald, self).__init__()
self.cutoff = 10.0 if "cut" not in controller.Command_Set else float(controller.Command_Set["cut"])
self.tolerance = 0.00001 if "PME_Direct_Tolerance" not in controller.Command_Set else float(
controller.Command_Set["PME_Direct_Tolerance"])
@ -43,12 +36,9 @@ class Particle_Mesh_Ewald(nn.Cell):
self.fftz = self.Get_Fft_Patameter(self.box_length[2])
self.beta = self.Get_Beta(self.cutoff, self.tolerance)
self.box_length = Tensor(np.asarray(self.box_length, np.float32), mstype.float32)
print("========== ", self.fftx, self.ffty, self.fftz, self.tolerance, self.beta)
def Get_Beta(self, cutoff, tolerance):
"""get beta"""
'''GET BETA'''
high = 1.0
ihigh = 1
while 1:
@ -69,7 +59,7 @@ class Particle_Mesh_Ewald(nn.Cell):
return beta
def Check_2357_Factor(self, number):
"""check 2357 factor"""
'''CHECK FACTOR'''
while number > 0:
if number == 1:
return 1
@ -101,7 +91,7 @@ class Particle_Mesh_Ewald(nn.Cell):
return 0
def Get_Fft_Patameter(self, length):
"""get fft parameter"""
'''GET FFT PARAMETER'''
tempi = math.ceil(length + 3) >> 2 << 2
if 60 <= tempi <= 68:
tempi = 64
@ -117,31 +107,3 @@ class Particle_Mesh_Ewald(nn.Cell):
if self.Check_2357_Factor(tempi):
return tempi
tempi += 4
def PME_Energy(self, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof, excluded_list_start,
               excluded_list, excluded_numbers, excluded_atom_numbers):
    """Build a ``P.PMEEnergy`` operator and evaluate the four PME energy terms.

    The operator is stored on ``self.pmee`` and its four outputs on
    ``self.reciprocal_energy``, ``self.self_energy``, ``self.direct_energy``
    and ``self.correction_energy``.

    Returns:
        Tuple ``(reciprocal_energy, self_energy, direct_energy, correction_energy)``.
    """
    self.pmee = P.PMEEnergy(self.atom_numbers, excluded_atom_numbers, self.beta, self.fftx, self.ffty, self.fftz)
    energy_terms = self.pmee(self.box_length, uint_crd, charge, nl_atom_numbers, nl_atom_serial,
                             uint_dr_to_dr_cof, excluded_list_start, excluded_list, excluded_numbers)
    (self.reciprocal_energy, self.self_energy,
     self.direct_energy, self.correction_energy) = energy_terms
    return self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy
def PME_Excluded_Force(self, uint_crd, scaler, charge, excluded_list_start, excluded_list,
                       excluded_numbers, excluded_atom_numbers):
    """Build a ``P.PMEExcludedForce`` operator and evaluate it.

    The operator is stored on ``self.pmeef`` and its output on ``self.frc``.

    Returns:
        The value stored on ``self.frc``.
    """
    op_settings = dict(atom_numbers=self.atom_numbers, excluded_numbers=excluded_atom_numbers, beta=self.beta)
    self.pmeef = P.PMEExcludedForce(**op_settings)
    op_inputs = (uint_crd, scaler, charge, excluded_list_start, excluded_list, excluded_numbers)
    self.frc = self.pmeef(*op_inputs)
    return self.frc
def PME_Reciprocal_Force(self, uint_crd, charge):
    """Build a ``P.PMEReciprocalForce`` operator and evaluate it.

    The operator is stored on ``self.pmerf`` and its output on ``self.frc``
    (the same attribute ``PME_Excluded_Force`` writes).

    Returns:
        The value stored on ``self.frc``.
    """
    fft_shape = (self.fftx, self.ffty, self.fftz)
    self.pmerf = P.PMEReciprocalForce(self.atom_numbers, self.beta, *fft_shape)
    self.frc = self.pmerf(self.box_length, uint_crd, charge)
    return self.frc
def Energy_Device_To_Host(self):
    """Sum the four stored PME energy terms into ``self.ee_ene`` and return it.

    Reads ``reciprocal_energy``, ``self_energy``, ``direct_energy`` and
    ``correction_energy`` from ``self``; they are added left to right.
    """
    running = self.reciprocal_energy + self.self_energy
    running = running + self.direct_energy
    self.ee_ene = running + self.correction_energy
    return self.ee_ene

View File

@ -0,0 +1,439 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''Simulation'''
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import nn
from mindspore.common.parameter import Parameter
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from src.angle import Angle
from src.bond import Bond
from src.dihedral import Dihedral
from src.langevin_liujian_md import Langevin_Liujian
from src.lennard_jones import Lennard_Jones_Information
from src.md_information import md_information
from src.nb14 import NON_BOND_14
from src.neighbor_list import neighbor_list
from src.particle_mesh_ewald import Particle_Mesh_Ewald
class controller:
    '''Holds the command-line file paths and the key/value commands read from the input file.'''

    def __init__(self, args_opt):
        '''Copy the paths from ``args_opt`` and parse the input file.

        Args:
            args_opt: object exposing the attributes ``i``, ``c``, ``amber_parm``,
                ``r``, ``x``, ``o`` and ``box``.
        '''
        self.input_file = args_opt.i
        self.initial_coordinates_file = args_opt.c
        self.amber_parm = args_opt.amber_parm
        self.restrt = args_opt.r
        self.mdcrd = args_opt.x
        self.mdout = args_opt.o
        self.mdbox = args_opt.box
        # flag -> raw string value, filled by commands_from_in_file()
        self.Command_Set = {}
        # first line of the input file, stripped
        self.md_task = None
        self.commands_from_in_file()

    def commands_from_in_file(self):
        '''Read ``self.input_file`` and fill ``self.md_task`` and ``self.Command_Set``.

        The first line becomes ``md_task``. Every line containing ``=`` is split
        into ``flag = value``; commas are removed from the value and spaces from
        the flag. The first occurrence of a flag wins; a repeated flag prints an
        error message and is ignored.

        Raises:
            AssertionError: if a line contains more than one ``=``.
            IndexError: if the file is empty.
        '''
        # The context manager closes the handle even if reading raises.
        with open(self.input_file, 'r') as file:
            context = file.readlines()
        self.md_task = context[0].strip()
        for val in context:
            if "=" in val:
                fields = val.strip().split("=")
                assert len(fields) == 2
                flag, value = fields
                value = value.replace(",", '')
                flag = flag.replace(" ", "")
                if flag not in self.Command_Set:
                    self.Command_Set[flag] = value
                else:
                    print("ERROR COMMAND FILE")
class Simulation(nn.Cell):
'''simulation'''
def __init__(self, args_opt):
    '''Load every simulation input described by ``args_opt``, then build tensors and operators.

    Creates the controller and the per-term information objects, copies the
    scalar settings the operators need onto ``self``, and finally calls
    ``init_Tensor()`` and ``op_define()``.
    '''
    super(Simulation, self).__init__()
    self.control = controller(args_opt)
    self.md_info = md_information(self.control)
    self.bond = Bond(self.control, self.md_info)
    self.angle = Angle(self.control)
    self.dihedral = Dihedral(self.control)
    self.nb14 = NON_BOND_14(self.control, self.dihedral, self.md_info.atom_numbers)
    self.nb_info = neighbor_list(self.control, self.md_info.atom_numbers, self.md_info.box_length)
    self.LJ_info = Lennard_Jones_Information(self.control)
    self.liujian_info = Langevin_Liujian(self.control, self.md_info.atom_numbers)
    self.pme_method = Particle_Mesh_Ewald(self.control, self.md_info)

    # Energy placeholders: one independent int32 zero tensor per attribute.
    for placeholder in ("bond_energy_sum", "angle_energy_sum", "dihedral_energy_sum",
                        "nb14_lj_energy_sum", "nb14_cf_energy_sum", "lj_energy_sum",
                        "ee_ene", "total_energy"):
        setattr(self, placeholder, Tensor(0, mstype.int32))

    # Init scalar
    self.ntwx = self.md_info.ntwx
    self.atom_numbers = self.md_info.atom_numbers
    self.residue_numbers = self.md_info.residue_numbers
    self.bond_numbers = self.bond.bond_numbers
    self.angle_numbers = self.angle.angle_numbers
    self.dihedral_numbers = self.dihedral.dihedral_numbers
    self.nb14_numbers = self.nb14.nb14_numbers

    # Neighbour-list settings are copied under the same attribute names.
    nb = self.nb_info
    for setting in ("Nxy", "grid_numbers", "max_atom_in_grid_numbers",
                    "max_neighbor_numbers", "excluded_atom_numbers"):
        setattr(self, setting, getattr(nb, setting))
    self.refresh_count = Parameter(Tensor(nb.refresh_count, mstype.int32), requires_grad=False)
    for setting in ("refresh_interval", "skin", "cutoff", "cutoff_square", "cutoff_with_skin",
                    "half_cutoff_with_skin", "cutoff_with_skin_square", "half_skin_square"):
        setattr(self, setting, getattr(nb, setting))

    pme = self.pme_method
    self.beta = pme.beta
    self.fftx = pme.fftx
    self.ffty = pme.ffty
    self.fftz = pme.fftz

    thermostat = self.liujian_info
    self.random_seed = thermostat.rand_seed
    self.dt = thermostat.dt
    self.half_dt = thermostat.half_dt
    self.exp_gamma = thermostat.exp_gamma

    self.init_Tensor()
    self.op_define()
def init_Tensor(self):
    '''Convert the host-side data of the information objects into tensors and parameters on ``self``.

    Values wrapped in ``Parameter(..., requires_grad=False)`` are created in the
    same order as before; all other values are plain ``Tensor`` objects.
    '''
    def i32(values):
        # int32 tensor from any array-like
        return Tensor(np.asarray(values, np.int32), mstype.int32)

    def f32(values):
        # float32 tensor from any array-like
        return Tensor(np.asarray(values, np.float32), mstype.float32)

    def frozen(tensor):
        # parameter that is excluded from gradient computation
        return Parameter(tensor, requires_grad=False)

    md, nb = self.md_info, self.nb_info
    n_atoms = self.atom_numbers

    # Coordinates, scale factors and per-atom data.
    self.crd = frozen(Tensor(np.float32(np.asarray(md.coordinate).reshape([n_atoms, 3])), mstype.float32))
    self.crd_to_uint_crd_cof = f32(md.crd_to_uint_crd_cof)
    self.uint_dr_to_dr_cof = frozen(f32(md.uint_dr_to_dr_cof))
    self.box_length = Tensor(md.box_length, mstype.float32)
    self.charge = f32(md.h_charge)
    self.old_crd = frozen(Tensor(np.zeros([n_atoms, 3], dtype=np.float32), mstype.float32))
    self.uint_crd = frozen(Tensor(np.zeros([n_atoms, 3], dtype=np.uint32), mstype.uint32))
    self.mass_inverse = Tensor(md.h_mass_inverse, mstype.float32)
    self.res_start = Tensor(md.h_res_start, mstype.int32)
    self.res_end = Tensor(md.h_res_end, mstype.int32)
    self.mass = Tensor(md.h_mass, mstype.float32)
    self.velocity = frozen(Tensor(md.velocity, mstype.float32))
    self.acc = frozen(Tensor(np.zeros([n_atoms, 3], np.float32), mstype.float32))

    # Bond terms.
    self.bond_atom_a = i32(self.bond.h_atom_a)
    self.bond_atom_b = i32(self.bond.h_atom_b)
    self.bond_k = f32(self.bond.h_k)
    self.bond_r0 = f32(self.bond.h_r0)

    # Angle terms.
    self.angle_atom_a = i32(self.angle.h_atom_a)
    self.angle_atom_b = i32(self.angle.h_atom_b)
    self.angle_atom_c = i32(self.angle.h_atom_c)
    self.angle_k = f32(self.angle.h_angle_k)
    self.angle_theta0 = f32(self.angle.h_angle_theta0)

    # Dihedral terms.
    self.dihedral_atom_a = i32(self.dihedral.h_atom_a)
    self.dihedral_atom_b = i32(self.dihedral.h_atom_b)
    self.dihedral_atom_c = i32(self.dihedral.h_atom_c)
    self.dihedral_atom_d = i32(self.dihedral.h_atom_d)
    self.pk = f32(self.dihedral.pk)
    self.gamc = f32(self.dihedral.gamc)
    self.gams = f32(self.dihedral.gams)
    self.pn = f32(self.dihedral.pn)
    self.ipn = i32(self.dihedral.ipn)

    # 1-4 non-bonded terms.
    self.nb14_atom_a = i32(self.nb14.h_atom_a)
    self.nb14_atom_b = i32(self.nb14.h_atom_b)
    self.lj_scale_factor = f32(self.nb14.h_lj_scale_factor)
    self.cf_scale_factor = f32(self.nb14.h_cf_scale_factor)

    # Neighbour-list state.
    self.grid_N = Tensor(nb.grid_N, mstype.int32)
    self.grid_length_inverse = Tensor(nb.grid_length_inverse, mstype.float32)
    bucket_shape = [self.grid_numbers, self.max_atom_in_grid_numbers]
    self.bucket = frozen(Tensor(np.asarray(nb.bucket, np.int32).reshape(bucket_shape), mstype.int32))
    self.atom_numbers_in_grid_bucket = frozen(Tensor(nb.atom_numbers_in_grid_bucket, mstype.int32))
    self.atom_in_grid_serial = frozen(Tensor(np.zeros([nb.atom_numbers,], np.int32), mstype.int32))
    self.pointer = frozen(
        Tensor(np.asarray(nb.pointer, np.int32).reshape([self.grid_numbers, 125]), mstype.int32))
    self.nl_atom_numbers = frozen(Tensor(np.zeros([n_atoms,], np.int32), mstype.int32))
    self.nl_atom_serial = frozen(
        Tensor(np.zeros([n_atoms, self.max_neighbor_numbers], np.int32), mstype.int32))
    self.excluded_list_start = i32(nb.excluded_list_start)
    self.excluded_list = i32(nb.excluded_list)
    self.excluded_numbers = i32(nb.excluded_numbers)
    self.need_refresh_flag = i32([0])

    # Lennard-Jones tables.
    self.atom_LJ_type = i32(self.LJ_info.atom_LJ_type)
    self.LJ_A = f32(self.LJ_info.LJ_A)
    self.LJ_B = f32(self.LJ_info.LJ_B)

    # Thermostat state; rand_state keeps Parameter's default requires_grad.
    self.sqrt_mass = Tensor(self.liujian_info.h_sqrt_mass, mstype.float32)
    self.rand_state = Parameter(Tensor(self.liujian_info.rand_state, mstype.float32))
    self.zero_fp_tensor = Tensor(np.asarray([0,], np.float32))
def op_define(self):
    '''Instantiate every operator used by ``construct`` once and store it on ``self``.

    Two ``P.NeighborListUpdate`` operators are created from the same settings:
    ``neighbor_list_update_init`` with ``not_first_time=0`` and
    ``neighbor_list_update`` with ``not_first_time=1``.
    '''
    n_atoms = self.atom_numbers
    box = self.md_info.box_length
    fft_and_box = (self.fftx, self.ffty, self.fftz, box[0], box[1], box[2])

    self.crd_to_uint_crd = P.CrdToUintCrd(n_atoms)
    self.mdtemp = P.MDTemperature(self.residue_numbers, n_atoms)
    self.setup_random_state = P.MDIterationSetupRandState(n_atoms, self.random_seed)

    # Force operators.
    self.bond_force_with_atom_energy = P.BondForceWithAtomEnergy(bond_numbers=self.bond_numbers,
                                                                 atom_numbers=n_atoms)
    self.angle_force_with_atom_energy = P.AngleForceWithAtomEnergy(angle_numbers=self.angle_numbers)
    self.dihedral_force_with_atom_energy = P.DihedralForceWithAtomEnergy(dihedral_numbers=self.dihedral_numbers)
    self.nb14_force_with_atom_energy = P.Dihedral14LJCFForceWithAtomEnergy(nb14_numbers=self.nb14_numbers,
                                                                           atom_numbers=n_atoms)
    self.lj_force_pme_direct_force = P.LJForceWithPMEDirectForce(n_atoms, self.cutoff, self.beta)
    self.pme_excluded_force = P.PMEExcludedForce(atom_numbers=n_atoms,
                                                 excluded_numbers=self.excluded_atom_numbers, beta=self.beta)
    self.pme_reciprocal_force = P.PMEReciprocalForce(n_atoms, self.beta, *fft_and_box)

    # Energy operators.
    self.bond_energy = P.BondEnergy(self.bond_numbers, n_atoms)
    self.angle_energy = P.AngleEnergy(self.angle_numbers)
    self.dihedral_energy = P.DihedralEnergy(self.dihedral_numbers)
    self.nb14_lj_energy = P.Dihedral14LJEnergy(self.nb14_numbers, n_atoms)
    self.nb14_cf_energy = P.Dihedral14CFEnergy(self.nb14_numbers, n_atoms)
    self.lj_energy = P.LJEnergy(n_atoms, self.cutoff_square)
    self.pme_energy = P.PMEEnergy(n_atoms, self.excluded_atom_numbers, self.beta, *fft_and_box)

    self.md_iteration_leap_frog_liujian = P.MDIterationLeapFrogLiujian(n_atoms, self.half_dt, self.dt,
                                                                       self.exp_gamma)

    # Settings shared by both neighbour-list operators; only not_first_time differs.
    nl_settings = dict(
        grid_numbers=self.grid_numbers,
        atom_numbers=n_atoms,
        Nxy=self.Nxy,
        excluded_atom_numbers=self.excluded_atom_numbers,
        cutoff_square=self.cutoff_square,
        half_skin_square=self.half_skin_square,
        cutoff_with_skin=self.cutoff_with_skin,
        half_cutoff_with_skin=self.half_cutoff_with_skin,
        cutoff_with_skin_square=self.cutoff_with_skin_square,
        refresh_interval=self.refresh_interval,
        cutoff=self.cutoff,
        skin=self.skin,
        max_atom_in_grid_numbers=self.max_atom_in_grid_numbers,
        max_neighbor_numbers=self.max_neighbor_numbers,
    )
    self.neighbor_list_update_init = P.NeighborListUpdate(not_first_time=0, **nl_settings)
    self.neighbor_list_update = P.NeighborListUpdate(not_first_time=1, **nl_settings)

    self.random_force = Tensor(np.zeros([n_atoms, 3], np.float32), mstype.float32)
def Simulation_Beforce_Caculate_Force(self):
    '''Convert ``self.crd`` to unsigned-integer coordinates.

    The conversion operator is called with half of ``self.crd_to_uint_crd_cof``
    as its scale factor.

    Returns:
        The output of ``self.crd_to_uint_crd``.
    '''
    half_scale = 0.5 * self.crd_to_uint_crd_cof
    return self.crd_to_uint_crd(half_scale, self.crd)
def Simulation_Caculate_Force(self, uint_crd, scaler, nl_atom_numbers, nl_atom_serial):
    '''Evaluate every force term and return their element-wise sum.

    Args:
        uint_crd: unsigned-integer coordinates passed to every force operator.
        scaler: scale-factor tensor passed to the operators that take one.
        nl_atom_numbers, nl_atom_serial: neighbour-list tensors for the LJ / PME direct term.

    Returns:
        ``P.AddN`` of the bond, angle, dihedral, 1-4, LJ/PME-direct,
        PME-excluded and PME-reciprocal forces.
    '''
    # The "with atom energy" operators return a pair; only the force is used here.
    f_bond, _ = self.bond_force_with_atom_energy(
        uint_crd, scaler, self.bond_atom_a, self.bond_atom_b, self.bond_k, self.bond_r0)
    f_angle, _ = self.angle_force_with_atom_energy(
        uint_crd, scaler, self.angle_atom_a, self.angle_atom_b, self.angle_atom_c,
        self.angle_k, self.angle_theta0)
    f_dihedral, _ = self.dihedral_force_with_atom_energy(
        uint_crd, scaler, self.dihedral_atom_a, self.dihedral_atom_b, self.dihedral_atom_c,
        self.dihedral_atom_d, self.ipn, self.pk, self.gamc, self.gams, self.pn)
    f_nb14, _ = self.nb14_force_with_atom_energy(
        uint_crd, self.atom_LJ_type, self.charge, scaler, self.nb14_atom_a, self.nb14_atom_b,
        self.lj_scale_factor, self.cf_scale_factor, self.LJ_A, self.LJ_B)
    f_lj = self.lj_force_pme_direct_force(
        uint_crd, self.atom_LJ_type, self.charge, scaler, nl_atom_numbers, nl_atom_serial,
        self.LJ_A, self.LJ_B)
    f_excluded = self.pme_excluded_force(
        uint_crd, scaler, self.charge, self.excluded_list_start, self.excluded_list,
        self.excluded_numbers)
    f_reciprocal = self.pme_reciprocal_force(uint_crd, self.charge)
    components = [f_bond, f_angle, f_dihedral, f_nb14, f_lj, f_excluded, f_reciprocal]
    return P.AddN()(components)
def Simulation_Caculate_Energy(self, uint_crd, uint_dr_to_dr_cof):
    '''Evaluate every energy term, reduce each to a sum and add them up.

    The neighbour-list tensors are read from ``self.nl_atom_numbers`` and
    ``self.nl_atom_serial``.

    Returns:
        Tuple ``(bond, angle, dihedral, nb14_lj, nb14_cf, lj, ee_ene, total)``
        where the first six are ``ReduceSum(True)`` results, ``ee_ene`` is the
        sum of the four PME outputs and ``total`` is ``AddN`` of the other seven.
    '''
    per_bond = self.bond_energy(uint_crd, uint_dr_to_dr_cof, self.bond_atom_a, self.bond_atom_b,
                                self.bond_k, self.bond_r0)
    bond_energy_sum = P.ReduceSum(True)(per_bond)

    per_angle = self.angle_energy(uint_crd, uint_dr_to_dr_cof, self.angle_atom_a, self.angle_atom_b,
                                  self.angle_atom_c, self.angle_k, self.angle_theta0)
    angle_energy_sum = P.ReduceSum(True)(per_angle)

    per_dihedral = self.dihedral_energy(uint_crd, uint_dr_to_dr_cof, self.dihedral_atom_a,
                                        self.dihedral_atom_b, self.dihedral_atom_c, self.dihedral_atom_d,
                                        self.ipn, self.pk, self.gamc, self.gams, self.pn)
    dihedral_energy_sum = P.ReduceSum(True)(per_dihedral)

    per_nb14_lj = self.nb14_lj_energy(uint_crd, self.atom_LJ_type, self.charge, uint_dr_to_dr_cof,
                                      self.nb14_atom_a, self.nb14_atom_b, self.lj_scale_factor,
                                      self.LJ_A, self.LJ_B)
    per_nb14_cf = self.nb14_cf_energy(uint_crd, self.atom_LJ_type, self.charge, uint_dr_to_dr_cof,
                                      self.nb14_atom_a, self.nb14_atom_b, self.cf_scale_factor)
    nb14_lj_energy_sum = P.ReduceSum(True)(per_nb14_lj)
    nb14_cf_energy_sum = P.ReduceSum(True)(per_nb14_cf)

    per_lj = self.lj_energy(uint_crd, self.atom_LJ_type, self.charge, uint_dr_to_dr_cof,
                            self.nl_atom_numbers, self.nl_atom_serial, self.LJ_A, self.LJ_B)
    lj_energy_sum = P.ReduceSum(True)(per_lj)

    reciprocal_energy, self_energy, direct_energy, correction_energy = self.pme_energy(
        uint_crd, self.charge, self.nl_atom_numbers, self.nl_atom_serial, uint_dr_to_dr_cof,
        self.excluded_list_start, self.excluded_list, self.excluded_numbers)
    ee_ene = reciprocal_energy + self_energy + direct_energy + correction_energy

    summed_terms = [bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum,
                    nb14_cf_energy_sum, lj_energy_sum, ee_ene]
    total_energy = P.AddN()(summed_terms)
    return bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, nb14_cf_energy_sum, \
           lj_energy_sum, ee_ene, total_energy
def Simulation_Temperature(self):
    '''Run the temperature operator on the stored velocities and masses and return the sum of its output.'''
    per_residue = self.mdtemp(self.res_start, self.res_end, self.velocity, self.mass)
    return P.ReduceSum()(per_residue)
def Simulation_MDIterationLeapFrog_Liujian(self, inverse_mass, sqrt_mass_inverse, crd, frc, rand_state, random_frc):
    '''Run one leap-frog step and return the velocity, coordinates and acceleration.

    The operator receives ``self.velocity`` and ``self.acc`` in addition to the
    arguments. Its result is used as the new coordinates, and the returned
    velocity and acceleration are ``self.velocity`` / ``self.acc`` tied to that
    result with ``F.depend``.

    Returns:
        Tuple ``(vel, crd, acc)``.
    '''
    stepped_crd = self.md_iteration_leap_frog_liujian(inverse_mass, sqrt_mass_inverse, self.velocity, crd,
                                                      frc, self.acc, rand_state, random_frc)
    stepped_vel = F.depend(self.velocity, stepped_crd)
    stepped_acc = F.depend(self.acc, stepped_crd)
    return stepped_vel, stepped_crd, stepped_acc
def construct(self, step, print_step):
    '''Run one simulation step.

    Args:
        step: step index; ``step == 0`` selects the initialisation branch.
        print_step: in the non-initial branch, energies are only evaluated
            when this equals 0; otherwise ``self.zero_fp_tensor`` is returned
            for every energy value.

    Returns:
        Tuple ``(temperature, total_energy, bond_energy_sum, angle_energy_sum,
        dihedral_energy_sum, nb14_lj_energy_sum, nb14_cf_energy_sum,
        lj_energy_sum, ee_ene, res)`` where ``res`` is the output of the last
        neighbour-list update.
    '''
    if step == 0:
        # First step: run the neighbour-list operator built with not_first_time=0.
        res = self.neighbor_list_update_init(self.atom_numbers_in_grid_bucket, self.bucket, self.crd,
                                             self.box_length, self.grid_N, self.grid_length_inverse,
                                             self.atom_in_grid_serial, self.old_crd, self.crd_to_uint_crd_cof,
                                             self.uint_crd, self.pointer, self.nl_atom_numbers, self.nl_atom_serial,
                                             self.uint_dr_to_dr_cof, self.excluded_list_start, self.excluded_list,
                                             self.excluded_numbers, self.need_refresh_flag, self.refresh_count)
        # Tie each piece of neighbour-list state to `res` via F.depend.
        # NOTE(review): presumably this orders the update before later reads - confirm
        # against MindSpore's depend semantics.
        self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res)
        self.nl_atom_serial = F.depend(self.nl_atom_serial, res)
        self.uint_dr_to_dr_cof = F.depend(self.uint_dr_to_dr_cof, res)
        self.old_crd = F.depend(self.old_crd, res)
        self.atom_numbers_in_grid_bucket = F.depend(self.atom_numbers_in_grid_bucket, res)
        self.bucket = F.depend(self.bucket, res)
        self.atom_in_grid_serial = F.depend(self.atom_in_grid_serial, res)
        self.pointer = F.depend(self.pointer, res)
        # Here the integer coordinates come from self.uint_crd tied to `res`, not
        # from Simulation_Beforce_Caculate_Force as in the other branch.
        uint_crd = F.depend(self.uint_crd, res)
        force = self.Simulation_Caculate_Force(uint_crd, self.uint_dr_to_dr_cof, self.nl_atom_numbers,
                                               self.nl_atom_serial)
        # Energies are always evaluated on the first step.
        bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, nb14_cf_energy_sum, \
        lj_energy_sum, ee_ene, total_energy = self.Simulation_Caculate_Energy(uint_crd, self.uint_dr_to_dr_cof)
        temperature = self.Simulation_Temperature()
        # The random state is only set up in this branch.
        self.rand_state = self.setup_random_state()
        self.velocity, self.crd, _ = self.Simulation_MDIterationLeapFrog_Liujian(self.mass_inverse,
                                                                                 self.sqrt_mass, self.crd, force,
                                                                                 self.rand_state,
                                                                                 self.random_force)
        # Refresh the neighbour list for the updated coordinates (not_first_time=1 operator).
        res = self.neighbor_list_update(self.atom_numbers_in_grid_bucket,
                                        self.bucket,
                                        self.crd,
                                        self.box_length,
                                        self.grid_N,
                                        self.grid_length_inverse,
                                        self.atom_in_grid_serial,
                                        self.old_crd,
                                        self.crd_to_uint_crd_cof,
                                        self.uint_crd,
                                        self.pointer,
                                        self.nl_atom_numbers,
                                        self.nl_atom_serial,
                                        self.uint_dr_to_dr_cof,
                                        self.excluded_list_start,
                                        self.excluded_list,
                                        self.excluded_numbers,
                                        self.need_refresh_flag,
                                        self.refresh_count)
        self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res)
        self.nl_atom_serial = F.depend(self.nl_atom_serial, res)
    else:
        # Later steps: recompute the integer coordinates from self.crd.
        uint_crd = self.Simulation_Beforce_Caculate_Force()
        force = self.Simulation_Caculate_Force(uint_crd, self.uint_dr_to_dr_cof, self.nl_atom_numbers,
                                               self.nl_atom_serial)
        if print_step == 0:
            bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, nb14_cf_energy_sum, \
            lj_energy_sum, ee_ene, total_energy = self.Simulation_Caculate_Energy(
                uint_crd, self.uint_dr_to_dr_cof)
        else:
            # Energies are skipped; every energy output is the shared zero tensor.
            bond_energy_sum = self.zero_fp_tensor
            angle_energy_sum = self.zero_fp_tensor
            dihedral_energy_sum = self.zero_fp_tensor
            nb14_lj_energy_sum = self.zero_fp_tensor
            nb14_cf_energy_sum = self.zero_fp_tensor
            lj_energy_sum = self.zero_fp_tensor
            ee_ene = self.zero_fp_tensor
            total_energy = self.zero_fp_tensor
        temperature = self.Simulation_Temperature()
        self.velocity, self.crd, _ = self.Simulation_MDIterationLeapFrog_Liujian(self.mass_inverse,
                                                                                 self.sqrt_mass, self.crd, force,
                                                                                 self.rand_state,
                                                                                 self.random_force)
        # Same neighbour-list refresh as at the end of the first-step branch.
        res = self.neighbor_list_update(self.atom_numbers_in_grid_bucket,
                                        self.bucket,
                                        self.crd,
                                        self.box_length,
                                        self.grid_N,
                                        self.grid_length_inverse,
                                        self.atom_in_grid_serial,
                                        self.old_crd,
                                        self.crd_to_uint_crd_cof,
                                        self.uint_crd,
                                        self.pointer,
                                        self.nl_atom_numbers,
                                        self.nl_atom_serial,
                                        self.uint_dr_to_dr_cof,
                                        self.excluded_list_start,
                                        self.excluded_list,
                                        self.excluded_numbers,
                                        self.need_refresh_flag,
                                        self.refresh_count)
        self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res)
        self.nl_atom_serial = F.depend(self.nl_atom_serial, res)
    return temperature, total_energy, bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, \
           nb14_cf_energy_sum, lj_energy_sum, ee_ene, res

View File

@ -1,245 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""simulation"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from .Langevin_Liujian_md import Langevin_Liujian
from .angle import Angle
from .bond import Bond
from .dihedral import Dihedral
from .lennard_jones import Lennard_Jones_Information
from .md_information import md_information
from .nb14 import NON_BOND_14
from .neighbor_list import nb_infomation
from .particle_mesh_ewald import Particle_Mesh_Ewald
class controller:
    """Holds the command-line file paths and the key/value commands read from the input file."""

    def __init__(self, args_opt):
        """Copy the paths from ``args_opt`` and parse the input file.

        Args:
            args_opt: object exposing the attributes ``i``, ``c``, ``amber_parm``,
                ``r``, ``x``, ``o`` and ``box``.
        """
        self.input_file = args_opt.i
        self.initial_coordinates_file = args_opt.c
        self.amber_parm = args_opt.amber_parm
        self.restrt = args_opt.r
        self.mdcrd = args_opt.x
        self.mdout = args_opt.o
        self.mdbox = args_opt.box
        # flag -> raw string value, filled by commands_from_in_file()
        self.Command_Set = {}
        # first line of the input file, stripped
        self.md_task = None
        self.commands_from_in_file()

    def commands_from_in_file(self):
        """Read ``self.input_file`` and fill ``self.md_task`` and ``self.Command_Set``.

        The first line becomes ``md_task``. Every line containing ``=`` is split
        into ``flag = value``; commas are removed from the value and spaces from
        the flag. The first occurrence of a flag wins; a repeated flag prints an
        error message and is ignored.

        Raises:
            AssertionError: if a line contains more than one ``=``.
            IndexError: if the file is empty.
        """
        # The context manager closes the handle even if reading raises.
        with open(self.input_file, 'r') as file:
            context = file.readlines()
        self.md_task = context[0].strip()
        for val in context:
            if "=" in val:
                fields = val.strip().split("=")
                assert len(fields) == 2
                flag, value = fields
                value = value.replace(",", '')
                flag = flag.replace(" ", "")
                if flag not in self.Command_Set:
                    self.Command_Set[flag] = value
                else:
                    print("ERROR COMMAND FILE")
class Simulation(nn.Cell):
"""class simulation"""
def __init__(self, args_opt):
    """Create the controller and every information object for the run described by ``args_opt``.

    Also stores the box size as a float32 tensor on ``self.box_length`` and
    initialises ``self.file`` to ``None``.
    """
    super(Simulation, self).__init__()
    self.control = controller(args_opt)
    self.md_info = md_information(self.control)
    self.bond = Bond(self.control, self.md_info)
    self.angle = Angle(self.control)
    self.dihedral = Dihedral(self.control)
    n_atoms = self.md_info.atom_numbers
    self.nb14 = NON_BOND_14(self.control, self.dihedral, n_atoms)
    self.nb_info = nb_infomation(self.control, n_atoms, self.md_info.box_length)
    self.LJ_info = Lennard_Jones_Information(self.control)
    self.liujian_info = Langevin_Liujian(self.control, n_atoms)
    self.pme_method = Particle_Mesh_Ewald(self.control, self.md_info)
    host_box = np.asarray(self.md_info.box_length, np.float32)
    self.box_length = Tensor(host_box, mstype.float32)
    self.file = None
def Main_Before_Calculate_Force(self):
    """Refresh the integer coordinates and bundle them with the LJ types and charges.

    Calls ``md_info.MD_Information_Crd_To_Uint_Crd()`` (result ignored), then
    stores ``(uint_crd, atom_LJ_type, charge)`` on ``md_info.uint_crd_with_LJ``.

    Returns:
        Tuple ``(md_info.uint_crd, md_info.uint_crd_with_LJ)``.
    """
    md = self.md_info
    md.MD_Information_Crd_To_Uint_Crd()
    bundled = (md.uint_crd, self.LJ_info.atom_LJ_type, md.charge)
    md.uint_crd_with_LJ = bundled
    return md.uint_crd, md.uint_crd_with_LJ
def Initial_Neighbor_List_Update(self, not_first_time):
    """Forward the current coordinates and scale factors to ``nb_info.NeighborListUpdate``.

    Args:
        not_first_time: passed through as the last positional argument.

    Returns:
        The result of ``nb_info.NeighborListUpdate``.
    """
    md = self.md_info
    update_args = (md.crd, md.crd_old, md.uint_crd, md.crd_to_uint_crd_cof, md.uint_dr_to_dr_cof,
                   self.box_length, not_first_time)
    return self.nb_info.NeighborListUpdate(*update_args)
def Main_Calculate_Force(self):
    """Evaluate every force term, add them on the host and store the total on ``md_info.frc``.

    Each term is converted with ``asnumpy()`` and accumulated in the order
    bond, angle, dihedral, 1-4, LJ/PME-direct, PME-excluded, PME-reciprocal.

    Returns:
        ``md_info.frc``, a float32 tensor built from the accumulated array.
    """
    self.bond.atom_numbers = self.md_info.atom_numbers
    md_info = self.md_info
    LJ_info = self.LJ_info
    nb_info = self.nb_info
    pme_method = self.pme_method

    terms = []
    # The "with atom energy" calls return a pair; only the force is kept.
    bond_frc, _ = self.bond.Bond_Force_With_Atom_Energy(md_info.uint_crd, md_info.uint_dr_to_dr_cof)
    terms.append(bond_frc)
    angle_frc, _ = self.angle.Angle_Force_With_Atom_Energy(md_info.uint_crd, md_info.uint_dr_to_dr_cof)
    terms.append(angle_frc)
    dihedral_frc, _ = self.dihedral.Dihedral_Force_With_Atom_Energy(md_info.uint_crd, md_info.uint_dr_to_dr_cof)
    terms.append(dihedral_frc)
    nb14_frc, _ = self.nb14.Non_Bond_14_LJ_CF_Force_With_Atom_Energy(
        md_info.uint_crd_with_LJ, md_info.uint_dr_to_dr_cof, LJ_info.LJ_A, LJ_info.LJ_B)
    terms.append(nb14_frc)
    terms.append(LJ_info.LJ_Force_With_PME_Direct_Force(
        md_info.atom_numbers, md_info.uint_crd_with_LJ, md_info.uint_dr_to_dr_cof, nb_info.nl_atom_numbers,
        nb_info.nl_atom_serial, nb_info.cutoff, pme_method.beta))
    terms.append(pme_method.PME_Excluded_Force(
        md_info.uint_crd, md_info.uint_dr_to_dr_cof, md_info.charge,
        nb_info.excluded_list_start, nb_info.excluded_list,
        nb_info.excluded_numbers, nb_info.excluded_atom_numbers))
    terms.append(pme_method.PME_Reciprocal_Force(md_info.uint_crd, md_info.charge))

    frc_t = 0
    for term in terms:
        frc_t += term.asnumpy()
    self.md_info.frc = Tensor(frc_t, mstype.float32)
    return self.md_info.frc
def Main_Calculate_Energy(self):
    """Trigger every energy evaluation in turn; all return values are discarded.

    Calls, in order: bond, angle, dihedral, 1-4 LJ/CF, LJ and PME energy, then
    ``pme_method.Energy_Device_To_Host()``.
    """
    md = self.md_info
    nb = self.nb_info
    lj = self.LJ_info
    pme = self.pme_method

    self.bond.Bond_Energy(md.uint_crd, md.uint_dr_to_dr_cof)
    self.angle.Angle_Energy(md.uint_crd, md.uint_dr_to_dr_cof)
    self.dihedral.Dihedral_Engergy(md.uint_crd, md.uint_dr_to_dr_cof)
    self.nb14.Non_Bond_14_LJ_CF_Energy(md.uint_crd_with_LJ, md.uint_dr_to_dr_cof, lj.LJ_A, lj.LJ_B)
    lj.LJ_Energy(md.uint_crd_with_LJ, md.uint_dr_to_dr_cof, nb.nl_atom_numbers, nb.nl_atom_serial,
                 nb.cutoff_square)
    pme.PME_Energy(md.uint_crd, md.charge, nb.nl_atom_numbers, nb.nl_atom_serial, md.uint_dr_to_dr_cof,
                   nb.excluded_list_start, nb.excluded_list, nb.excluded_numbers, nb.excluded_atom_numbers)
    pme.Energy_Device_To_Host()
def Main_After_Calculate_Energy(self):
    """Sum the stored energy terms into ``md_info.total_potential_energy``.

    The total is rebuilt from zero by adding, in order: the bond, angle and
    dihedral sums, the combined 1-4 LJ + CF sum, and the LJ sum. After that
    ``pme_method.Energy_Device_To_Host()`` is called and ``pme_method.ee_ene``
    is added last. The resulting total is printed.
    """
    info = self.md_info
    pair14 = self.nb14
    ewald = self.pme_method

    # Terms are listed in the exact order they are accumulated; the two 1-4
    # parts are added to each other first and then to the running total.
    contributions = (
        self.bond.sigma_of_bond_ene,
        self.angle.sigma_of_angle_ene,
        self.dihedral.sigma_of_dihedral_ene,
        pair14.nb14_lj_energy_sum + pair14.nb14_cf_energy_sum,
        self.LJ_info.LJ_energy_sum,
    )

    info.total_potential_energy = 0
    for term in contributions:
        info.total_potential_energy += term

    # ee_ene is read only after the device-to-host call has been made.
    ewald.Energy_Device_To_Host()
    info.total_potential_energy += ewald.ee_ene

    print("md_info.total_potential_energy", info.total_potential_energy)
def Main_Iteration_2(self):
    """Run the optional leap-frog update, then the post-iteration step.

    When ``md_info.mode`` is positive and the ``"thermostat"`` entry of
    ``control.Command_Set`` converts to the integer 1, the velocities,
    coordinates, forces and accelerations on ``md_info`` are replaced by the
    four values returned from ``liujian_info.MD_Iteration_Leap_Frog``.
    ``Main_After_Iteration`` is then called regardless of that condition.
    """
    state = self.md_info

    # The thermostat setting is only looked up when the mode check passes.
    use_leap_frog = state.mode > 0 and int(self.control.Command_Set["thermostat"]) == 1
    if use_leap_frog:
        new_vel, new_crd, new_frc, new_acc = self.liujian_info.MD_Iteration_Leap_Frog(
            state.d_mass_inverse, state.vel, state.crd, state.frc)
        state.vel = new_vel
        state.crd = new_crd
        state.frc = new_frc
        state.acc = new_acc

    self.Main_After_Iteration()
def Main_After_Iteration(self):
    """Call ``md_info.Centerize()`` and then update the neighbour list.

    ``nb_info.NeighborListUpdate`` receives the current and old coordinates,
    the unsigned-int coordinates, both conversion coefficients and
    ``self.box_length``, with ``not_first_time=1``. Its return value is
    discarded.
    """
    state = self.md_info
    state.Centerize()

    # Positional arguments gathered in call order; the flag stays a keyword.
    update_args = (
        state.crd,
        state.crd_old,
        state.uint_crd,
        state.crd_to_uint_crd_cof,
        state.uint_dr_to_dr_cof,
        self.box_length,
    )
    self.nb_info.NeighborListUpdate(*update_args, not_first_time=1)
def Main_Print(self):
    """Print one report row and optionally append it to the output file.

    The temperature comes from ``md_info.MD_Information_Temperature()``; the
    raw result is stored in ``md_info.h_temperature``. Every reported value is
    converted with ``asnumpy()`` and then ``float()`` before formatting.

    On stdout the bond, angle, dihedral and 1-4 columns are printed only when
    the matching ``*_numbers`` count is positive. The row written to
    ``self.file`` (when it is not ``None``) always contains every column.

    Returns:
        The ``asnumpy()`` form of the temperature.
    """
    info = self.md_info
    raw_temperature = info.MD_Information_Temperature()
    info.h_temperature = raw_temperature
    step_count = info.steps

    # Pull every value through asnumpy() before any output is produced.
    temperature = raw_temperature.asnumpy()
    potential = info.total_potential_energy.asnumpy()
    bond_ene = self.bond.sigma_of_bond_ene.asnumpy()
    angle_ene = self.angle.sigma_of_angle_ene.asnumpy()
    dihedral_ene = self.dihedral.sigma_of_dihedral_ene.asnumpy()
    lj14_ene = self.nb14.nb14_lj_energy_sum.asnumpy()
    cf14_ene = self.nb14.nb14_cf_energy_sum.asnumpy()
    lj_ene = self.LJ_info.LJ_energy_sum.asnumpy()
    pme_ene = self.pme_method.ee_ene.asnumpy()

    print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
          "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_")
    print("{:>7.0f} {:>7.3f} {:>11.3f}".format(step_count, float(temperature), float(potential)), end=" ")

    # (term count, column template, values) for the columns that may be omitted.
    optional_columns = (
        (self.bond.bond_numbers, "{:>10.3f}", (bond_ene,)),
        (self.angle.angle_numbers, "{:>11.3f}", (angle_ene,)),
        (self.dihedral.dihedral_numbers, "{:>14.3f}", (dihedral_ene,)),
        (self.nb14.nb14_numbers, "{:>10.3f} {:>10.3f}", (lj14_ene, cf14_ene)),
    )
    for term_count, template, values in optional_columns:
        if term_count > 0:
            print(template.format(*[float(value) for value in values]), end=" ")

    print("{:>7.3f}".format(float(lj_ene)), end=" ")
    print("{:>12.3f}".format(float(pme_ene)))

    if self.file is not None:
        row = ("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}"
               " {:>12.3f}\n").format(
                   step_count, float(temperature), float(potential), float(bond_ene),
                   float(angle_ene), float(dihedral_ene), float(lj14_ene), float(cf14_ene),
                   float(lj_ene), float(pme_ene))
        self.file.write(row)

    return temperature
def Main_Initial(self):
    """Open the output file and write its column header line.

    Nothing happens when ``control.mdout`` is falsy. Otherwise the path is
    opened for writing, the handle is stored in ``self.file`` and the header
    row is written to it. The handle is left open.
    """
    if not self.control.mdout:
        return

    self.file = open(self.control.mdout, 'w')
    header = ("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
              "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n")
    self.file.write(header)
def Main_Destroy(self):
    """Close the output file, if one is open, and report success.

    When ``self.file`` is not ``None`` it is closed, ``self.file`` is reset to
    ``None`` and a confirmation message is printed. Resetting the attribute
    means a repeated call is a no-op and ``Main_Print`` (which checks
    ``self.file is not None``) will not try to write to a closed handle.
    """
    if self.file is not None:
        self.file.close()
        # Drop the closed handle so later code does not treat it as usable.
        self.file = None
        print("Save successfully!")