From: @jiahongqian
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-17 21:31:12 +08:00 committed by Gitee
commit e789642e9a
13 changed files with 2602 additions and 606 deletions

View File

@ -87,7 +87,7 @@ class UncertaintyEvaluation:
self.epi_model = deepcopy(model)
self.ale_model = deepcopy(model)
self.epi_train_dataset = train_dataset
self.ale_train_dataset = train_dataset
self.ale_train_dataset = deepcopy(train_dataset)
self.task_type = task_type
self.epochs = Validator.check_positive_int(epochs)
self.epi_uncer_model_path = epi_uncer_model_path
@ -101,7 +101,8 @@ class UncertaintyEvaluation:
if not isinstance(model, Cell):
raise TypeError('The model should be Cell type.')
if task_type not in ('regression', 'classification'):
raise ValueError('The task should be regression or classification.')
raise ValueError(
'The task should be regression or classification.')
if task_type == 'classification':
self.num_classes = Validator.check_positive_int(num_classes)
else:
@ -119,15 +120,19 @@ class UncertaintyEvaluation:
self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model)
if self.epi_uncer_model.drop_count == 0 and self.epi_train_dataset is not None:
if self.task_type == 'classification':
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_loss = SoftmaxCrossEntropyWithLogits(
sparse=True, reduction="mean")
net_opt = Adam(self.epi_uncer_model.trainable_params())
model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
model = Model(self.epi_uncer_model, net_loss,
net_opt, metrics={"Accuracy": Accuracy()})
else:
net_loss = MSELoss()
net_opt = Adam(self.epi_uncer_model.trainable_params())
model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
model = Model(self.epi_uncer_model, net_loss,
net_opt, metrics={"MSE": MSE()})
if self.save_model:
config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs)
config_ck = CheckpointConfig(
keep_checkpoint_max=self.epochs)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
directory=self.epi_uncer_model_path,
config=config_ck)
@ -137,7 +142,8 @@ class UncertaintyEvaluation:
model.train(self.epochs, self.epi_train_dataset, dataset_sink_mode=False,
callbacks=[LossMonitor()])
else:
uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
uncer_param_dict = load_checkpoint(
self.epi_uncer_model_path)
load_param_into_net(self.epi_uncer_model, uncer_param_dict)
def _eval_epistemic_uncertainty(self, eval_data, mc=10):
@ -164,15 +170,19 @@ class UncertaintyEvaluation:
Get the model which can obtain the aleatoric uncertainty.
"""
if self.ale_train_dataset is None:
raise ValueError('The train dataset should not be None when evaluating aleatoric uncertainty.')
raise ValueError(
'The train dataset should not be None when evaluating aleatoric uncertainty.')
if self.ale_uncer_model is None:
self.ale_uncer_model = AleatoricUncertaintyModel(self.ale_model, self.num_classes, self.task_type)
self.ale_uncer_model = AleatoricUncertaintyModel(
self.ale_model, self.num_classes, self.task_type)
net_loss = AleatoricLoss(self.task_type)
net_opt = Adam(self.ale_uncer_model.trainable_params())
if self.task_type == 'classification':
model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
model = Model(self.ale_uncer_model, net_loss,
net_opt, metrics={"Accuracy": Accuracy()})
else:
model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
model = Model(self.ale_uncer_model, net_loss,
net_opt, metrics={"MSE": MSE()})
if self.save_model:
config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
@ -284,7 +294,8 @@ class AleatoricUncertaintyModel(Cell):
self.ale_model = ale_model
self.var_layer = Dense(num_classes, num_classes)
else:
self.ale_model, self.var_layer, self.pred_layer = self._make_aleatoric(ale_model)
self.ale_model, self.var_layer, self.pred_layer = self._make_aleatoric(
ale_model)
def construct(self, x):
if self.task == 'classification':
@ -327,7 +338,8 @@ class AleatoricLoss(Cell):
self.exp = P.Exp()
self.normal = C.normal
self.to_tensor = P.ScalarToArray()
self.entropy = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
self.entropy = SoftmaxCrossEntropyWithLogits(
sparse=True, reduction="mean")
else:
self.mean = P.ReduceMean()
self.exp = P.Exp()
@ -337,7 +349,8 @@ class AleatoricLoss(Cell):
y_pred, var = data_pred
if self.task == 'classification':
sample_times = 10
epsilon = self.normal((1, sample_times), self.to_tensor(0.0), self.to_tensor(1.0), 0)
epsilon = self.normal((1, sample_times), self.to_tensor(
0.0), self.to_tensor(1.0), 0)
total_loss = 0
for i in range(sample_times):
y_pred_i = y_pred + epsilon[0][i] * var
@ -345,5 +358,6 @@ class AleatoricLoss(Cell):
total_loss += loss
avg_loss = total_loss / sample_times
return avg_loss
loss = self.mean(0.5 * self.exp(-var) * self.pow(y - y_pred, 2) + 0.5 * var)
loss = self.mean(0.5 * self.exp(-var) *
self.pow(y - y_pred, 2) + 0.5 * var)
return loss

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""main"""
import time
import argparse
from mindspore import context
from src.simulation_initial import Simulation
parser = argparse.ArgumentParser(description='Sponge Controller')
parser.add_argument('--i', type=str, default=None, help='input file')
parser.add_argument('--amber_parm', type=str, default=None,
help='paramter file in AMBER type')
parser.add_argument('--c', type=str, default=None,
help='initial coordinates file')
parser.add_argument('--r', type=str, default="restrt", help='')
parser.add_argument('--x', type=str, default="mdcrd", help='')
parser.add_argument('--o', type=str, default="mdout", help="")
parser.add_argument('--box', type=str, default="mdbox", help='')
args_opt = parser.parse_args()
context.set_context(mode=context.PYNATIVE_MODE,
device_target="GPU", device_id=0, save_graphs=True)
if __name__ == "__main__":
start = time.time()
simulation = Simulation(args_opt)
simulation.Main_Initial()
res = simulation.Initial_Neighbor_List_Update(not_first_time=0)
md_info = simulation.md_info
md_info.step_limit = 1
for i in range(1, md_info.step_limit + 1):
print("steps: ", i)
md_info.steps = i
simulation.Main_Before_Calculate_Force()
simulation.Main_Calculate_Force()
simulation.Main_Calculate_Energy()
simulation.Main_After_Calculate_Energy()
temperature = simulation.Main_Print()
simulation.Main_Iteration_2()
end = time.time()
print("Main time(s):", end - start)
simulation.Main_Destroy()

View File

@ -0,0 +1,109 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Langevin Liujian MD class"""
import math
import numpy as np
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
class Langevin_Liujian:
"""Langevin_Liujian class"""
def __init__(self, controller, atom_numbers):
self.atom_numbers = atom_numbers
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.CONSTANT_TIME_CONVERTION = 20.455
self.CONSTANT_kB = 0.00198716
self.target_temperature = 300.0 if "target_temperature" not in controller.Command_Set else float(
controller.Command_Set["target_temperature"])
self.gamma_ln = 1.0 if "langevin_gamma" not in controller.Command_Set else float(
controller.Command_Set["langevin_gamma"])
self.rand_seed = 0 if "langevin_seed" not in controller.Command_Set else float(
controller.Command_Set["langevin_seed"]) # jiahong0315
self.max_velocity = 10000.0 if "velocity_max" not in controller.Command_Set else float(
controller.Command_Set["velocity_max"])
assert self.max_velocity > 0
self.is_max_velocity = 0 if "velocity_max" not in controller.Command_Set else 1
print("target temperature is ", self.target_temperature)
print("friction coefficient is ", self.gamma_ln, "ps^-1")
print("random seed is ", self.rand_seed)
self.dt = float(controller.Command_Set["dt"])
self.dt *= self.CONSTANT_TIME_CONVERTION
self.half_dt = 0.5 * self.dt
self.float4_numbers = math.ceil(3.0 * self.atom_numbers / 4.0)
self.gamma_ln = self.gamma_ln / self.CONSTANT_TIME_CONVERTION
self.exp_gamma = math.exp(-1 * self.gamma_ln * self.dt)
self.sqrt_gamma = math.sqrt((1. - self.exp_gamma * self.exp_gamma) * self.target_temperature * self.CONSTANT_kB)
self.h_sqrt_mass = [0] * self.atom_numbers
for i in range(self.atom_numbers):
self.h_sqrt_mass[i] = self.sqrt_gamma * math.sqrt(1. / self.h_mass[i])
self.d_sqrt_mass = Tensor(self.h_sqrt_mass, mstype.float32)
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.h_mass = [0] * self.atom_numbers
for idx, val in enumerate(context):
if "%FLAG MASS" in val:
count = 0
start_idx = idx
information = []
while count < self.atom_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.atom_numbers):
self.h_mass[i] = information[i]
break
def MDIterationLeapFrog_Liujian(self, atom_numbers, half_dt, dt, exp_gamma, inverse_mass, sqrt_mass_inverse, vel,
crd, frc, random_frc):
"""compute MDIterationLeapFrog Liujian"""
inverse_mass = inverse_mass.reshape((-1, 1))
sqrt_mass_inverse = sqrt_mass_inverse.reshape((-1, 1))
acc = inverse_mass * frc
vel = vel + dt * acc
crd = crd + half_dt * vel
vel = exp_gamma * vel + sqrt_mass_inverse * random_frc
crd = crd + half_dt * vel
frc = Tensor(np.zeros((atom_numbers, 3)), mstype.float32)
return vel, crd, frc, acc
def MD_Iteration_Leap_Frog(self, d_mass_inverse, vel_in, crd_in, frc_in):
"""MD_Iteration_Leap_Frog"""
np.random.seed(int(self.rand_seed))
self.rand_force = Tensor(np.zeros((self.atom_numbers, 3)), mstype.float32)
# self.rand_force = Tensor(np.random.randn(self.atom_numbers, 3), mstype.float32)
vel, crd, frc, acc = self.MDIterationLeapFrog_Liujian(atom_numbers=self.atom_numbers, half_dt=self.half_dt,
dt=self.dt, exp_gamma=self.exp_gamma,
inverse_mass=d_mass_inverse,
sqrt_mass_inverse=self.d_sqrt_mass,
vel=vel_in, crd=crd_in,
frc=frc_in, random_frc=self.rand_force)
return vel, crd, frc, acc

View File

@ -0,0 +1,175 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""angle class"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class Angle(nn.Cell):
"""Angle class"""
def __init__(self, controller):
super(Angle, self).__init__()
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
self.atom_c = Tensor(np.asarray(self.h_atom_c, np.int32), mstype.int32)
self.angle_k = Tensor(np.asarray(self.h_angle_k, np.float32), mstype.float32)
self.angle_theta0 = Tensor(np.asarray(self.h_angle_theta0, np.float32), mstype.float32)
def read_process1(self, context):
"""read_information_from_amberfile process1"""
for idx, val in enumerate(context):
if idx < len(context) - 1:
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
start_idx = idx + 2
count = 0
value = list(map(int, context[start_idx].strip().split()))
self.angle_with_H_numbers = value[4]
self.angle_without_H_numbers = value[5]
self.angle_numbers = self.angle_with_H_numbers + self.angle_without_H_numbers
information = []
information.extend(value)
while count < 15:
start_idx += 1
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.angle_type_numbers = information[16]
print("angle type numbers ", self.angle_type_numbers)
break
def read_process2(self, context):
"""read_information_from_amberfile process2"""
angle_count = 0
for idx, val in enumerate(context):
if "%FLAG ANGLES_INC_HYDROGEN" in val:
count = 0
start_idx = idx
information = []
while count < 4 * self.angle_with_H_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for _ in range(self.angle_with_H_numbers):
self.h_atom_a[angle_count] = information[angle_count * 4 + 0] / 3
self.h_atom_b[angle_count] = information[angle_count * 4 + 1] / 3
self.h_atom_c[angle_count] = information[angle_count * 4 + 2] / 3
self.h_type[angle_count] = information[angle_count * 4 + 3] - 1
angle_count += 1
break
return angle_count
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.read_process1(context)
self.h_atom_a = [0] * self.angle_numbers
self.h_atom_b = [0] * self.angle_numbers
self.h_atom_c = [0] * self.angle_numbers
self.h_type = [0] * self.angle_numbers
angle_count = self.read_process2(context)
for idx, val in enumerate(context):
if "%FLAG ANGLES_WITHOUT_HYDROGEN" in val:
count = 0
start_idx = idx
information = []
while count < 4 * self.angle_without_H_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.angle_without_H_numbers):
self.h_atom_a[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 0] / 3
self.h_atom_b[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 1] / 3
self.h_atom_c[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 2] / 3
self.h_type[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 3] - 1
angle_count += 1
break
self.type_k = [0] * self.angle_type_numbers
for idx, val in enumerate(context):
if "%FLAG ANGLE_FORCE_CONSTANT" in val:
count = 0
start_idx = idx
information = []
while count < self.angle_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
# print(start_idx)
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.type_k = information[:self.angle_type_numbers]
break
self.type_theta0 = [0] * self.angle_type_numbers
for idx, val in enumerate(context):
if "%FLAG ANGLE_EQUIL_VALUE" in val:
count = 0
start_idx = idx
information = []
while count < self.angle_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.type_theta0 = information[:self.angle_type_numbers]
break
if self.angle_numbers != angle_count:
print("angle count %d != angle_number %d ", angle_count, self.angle_numbers)
self.h_angle_k = []
self.h_angle_theta0 = []
for i in range(self.angle_numbers):
self.h_angle_k.append(self.type_k[self.h_type[i]])
self.h_angle_theta0.append(self.type_theta0[self.h_type[i]])
def Angle_Energy(self, uint_crd, uint_dr_to_dr_cof):
"""compute angle energy"""
self.angle_energy = P.AngleEnergy(self.angle_numbers)(uint_crd, uint_dr_to_dr_cof, self.atom_a, self.atom_b,
self.atom_c, self.angle_k, self.angle_theta0)
self.sigma_of_angle_ene = P.ReduceSum()(self.angle_energy)
return self.sigma_of_angle_ene
def Angle_Force_With_Atom_Energy(self, uint_crd, scaler):
"""compute angle force with atom energy"""
print("angele angle numbers:", self.angle_numbers)
self.afae = P.AngleForceWithAtomEnergy(angle_numbers=self.angle_numbers)
frc, ene = self.afae(uint_crd, scaler, self.atom_a, self.atom_b, self.atom_c, self.angle_k, self.angle_theta0)
return frc, ene

View File

@ -0,0 +1,163 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""bond class"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class Bond(nn.Cell):
"""bond class"""
def __init__(self, controller, md_info):
super(Bond, self).__init__()
self.atom_numbers = md_info.atom_numbers
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
self.bond_k = Tensor(np.asarray(self.h_k, np.float32), mstype.float32)
self.bond_r0 = Tensor(np.asarray(self.h_r0, np.float32), mstype.float32)
def process1(self, context):
"""process1: read information from amberfile"""
for idx, val in enumerate(context):
if idx < len(context) - 1:
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
start_idx = idx + 2
count = 0
value = list(map(int, context[start_idx].strip().split()))
self.bond_with_hydrogen = value[2]
self.bond_numbers = value[3]
self.bond_numbers += self.bond_with_hydrogen
print(self.bond_numbers)
information = []
information.extend(value)
while count < 16:
start_idx += 1
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.bond_type_numbers = information[15]
print("bond type numbers ", self.bond_type_numbers)
break
for idx, val in enumerate(context):
if "%FLAG BOND_FORCE_CONSTANT" in val:
count = 0
start_idx = idx
information = []
while count < self.bond_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.bond_type_k = information[:self.bond_type_numbers]
break
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.process1(context)
for idx, val in enumerate(context):
if "%FLAG BOND_EQUIL_VALUE" in val:
count = 0
start_idx = idx
information = []
while count < self.bond_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.bond_type_r = information[:self.bond_type_numbers]
break
for idx, val in enumerate(context):
if "%FLAG BONDS_INC_HYDROGEN" in val:
self.h_atom_a = [0] * self.bond_numbers
self.h_atom_b = [0] * self.bond_numbers
self.h_k = [0] * self.bond_numbers
self.h_r0 = [0] * self.bond_numbers
count = 0
start_idx = idx
information = []
while count < 3 * self.bond_with_hydrogen:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.bond_with_hydrogen):
self.h_atom_a[i] = information[3 * i + 0] / 3
self.h_atom_b[i] = information[3 * i + 1] / 3
tmpi = information[3 * i + 2] - 1
self.h_k[i] = self.bond_type_k[tmpi]
self.h_r0[i] = self.bond_type_r[tmpi]
break
for idx, val in enumerate(context):
if "%FLAG BONDS_WITHOUT_HYDROGEN" in val:
count = 0
start_idx = idx
information = []
while count < 3 * (self.bond_numbers - self.bond_with_hydrogen):
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.bond_with_hydrogen, self.bond_numbers):
self.h_atom_a[i] = information[3 * (i - self.bond_with_hydrogen) + 0] / 3
self.h_atom_b[i] = information[3 * (i - self.bond_with_hydrogen) + 1] / 3
tmpi = information[3 * (i - self.bond_with_hydrogen) + 2] - 1
self.h_k[i] = self.bond_type_k[tmpi]
self.h_r0[i] = self.bond_type_r[tmpi]
break
def Bond_Energy(self, uint_crd, uint_dr_to_dr_cof):
"""compute bond energy"""
self.bond_energy = P.BondEnergy(self.bond_numbers, self.atom_numbers)(uint_crd, uint_dr_to_dr_cof, self.atom_a,
self.atom_b, self.bond_k, self.bond_r0)
self.sigma_of_bond_ene = P.ReduceSum()(self.bond_energy)
return self.sigma_of_bond_ene
def Bond_Force_With_Atom_Energy(self, uint_crd, scaler):
"""compute bond force with atom energy"""
self.bfatomenergy = P.BondForceWithAtomEnergy(bond_numbers=self.bond_numbers,
atom_numbers=self.atom_numbers)
frc, atom_energy = self.bfatomenergy(uint_crd, scaler, self.atom_a, self.atom_b, self.bond_k, self.bond_r0)
return frc, atom_energy

View File

@ -0,0 +1,221 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dihedral class"""
import math
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class Dihedral(nn.Cell):
"""dihedral class"""
def __init__(self, controller):
super(Dihedral, self).__init__()
self.CONSTANT_Pi = 3.1415926535897932
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
self.atom_c = Tensor(np.asarray(self.h_atom_c, np.int32), mstype.int32)
self.atom_d = Tensor(np.asarray(self.h_atom_d, np.int32), mstype.int32)
self.pk = Tensor(np.asarray(self.pk, np.float32), mstype.float32)
self.gamc = Tensor(np.asarray(self.gamc, np.float32), mstype.float32)
self.gams = Tensor(np.asarray(self.gams, np.float32), mstype.float32)
self.pn = Tensor(np.asarray(self.pn, np.float32), mstype.float32)
self.ipn = Tensor(np.asarray(self.ipn, np.int32), mstype.int32)
def process1(self, context):
"""process1: read information from amberfile"""
for idx, val in enumerate(context):
if idx < len(context) - 1:
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
start_idx = idx + 2
count = 0
value = list(map(int, context[start_idx].strip().split()))
self.dihedral_with_hydrogen = value[6]
self.dihedral_numbers = value[7]
self.dihedral_numbers += self.dihedral_with_hydrogen
information = []
information.extend(value)
while count < 15:
start_idx += 1
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.dihedral_type_numbers = information[17]
print("dihedral type numbers ", self.dihedral_type_numbers)
break
self.phase_type = [0] * self.dihedral_type_numbers
self.pk_type = [0] * self.dihedral_type_numbers
self.pn_type = [0] * self.dihedral_type_numbers
for idx, val in enumerate(context):
if "%FLAG DIHEDRAL_FORCE_CONSTANT" in val:
count = 0
start_idx = idx
information = []
while count < self.dihedral_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.pk_type = information[:self.dihedral_type_numbers]
break
for idx, val in enumerate(context):
if "%FLAG DIHEDRAL_PHASE" in val:
count = 0
start_idx = idx
information = []
while count < self.dihedral_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.phase_type = information[:self.dihedral_type_numbers]
break
for idx, val in enumerate(context):
if "%FLAG DIHEDRAL_PERIODICITY" in val:
count = 0
start_idx = idx
information = []
while count < self.dihedral_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.pn_type = information[:self.dihedral_type_numbers]
break
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.process1(context)
self.h_atom_a = [0] * self.dihedral_numbers
self.h_atom_b = [0] * self.dihedral_numbers
self.h_atom_c = [0] * self.dihedral_numbers
self.h_atom_d = [0] * self.dihedral_numbers
self.pk = []
self.gamc = []
self.gams = []
self.pn = []
self.ipn = []
for idx, val in enumerate(context):
if "%FLAG DIHEDRALS_INC_HYDROGEN" in val:
count = 0
start_idx = idx
information = []
while count < 5 * self.dihedral_with_hydrogen:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.dihedral_with_hydrogen):
self.h_atom_a[i] = information[i * 5 + 0] / 3
self.h_atom_b[i] = information[i * 5 + 1] / 3
self.h_atom_c[i] = information[i * 5 + 2] / 3
self.h_atom_d[i] = abs(information[i * 5 + 3] / 3)
tmpi = information[i * 5 + 4] - 1
self.pk.append(self.pk_type[tmpi])
tmpf = self.phase_type[tmpi]
if abs(tmpf - self.CONSTANT_Pi) <= 0.001:
tmpf = self.CONSTANT_Pi
tmpf2 = math.cos(tmpf)
if abs(tmpf2) < 1e-6:
tmpf2 = 0
self.gamc.append(tmpf2 * self.pk[i])
tmpf2 = math.sin(tmpf)
if abs(tmpf2) < 1e-6:
tmpf2 = 0
self.gams.append(tmpf2 * self.pk[i])
self.pn.append(abs(self.pn_type[tmpi]))
self.ipn.append(int(self.pn[i] + 0.001))
break
for idx, val in enumerate(context):
if "%FLAG DIHEDRALS_WITHOUT_HYDROGEN" in val:
count = 0
start_idx = idx
information = []
while count < 5 * (self.dihedral_numbers - self.dihedral_with_hydrogen):
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.dihedral_with_hydrogen, self.dihedral_numbers):
self.h_atom_a[i] = information[(i - self.dihedral_with_hydrogen) * 5 + 0] / 3
self.h_atom_b[i] = information[(i - self.dihedral_with_hydrogen) * 5 + 1] / 3
self.h_atom_c[i] = information[(i - self.dihedral_with_hydrogen) * 5 + 2] / 3
self.h_atom_d[i] = abs(information[(i - self.dihedral_with_hydrogen) * 5 + 3] / 3)
tmpi = information[(i - self.dihedral_with_hydrogen) * 5 + 4] - 1
self.pk.append(self.pk_type[tmpi])
tmpf = self.phase_type[tmpi]
if abs(tmpf - self.CONSTANT_Pi) <= 0.001:
tmpf = self.CONSTANT_Pi
tmpf2 = math.cos(tmpf)
if abs(tmpf2) < 1e-6:
tmpf2 = 0
self.gamc.append(tmpf2 * self.pk[i])
tmpf2 = math.sin(tmpf)
if abs(tmpf2) < 1e-6:
tmpf2 = 0
self.gams.append(tmpf2 * self.pk[i])
self.pn.append(abs(self.pn_type[tmpi]))
self.ipn.append(int(self.pn[i] + 0.001))
break
for i in range(self.dihedral_numbers):
if self.h_atom_c[i] < 0:
self.h_atom_c[i] *= -1
def Dihedral_Engergy(self, uint_crd, uint_dr_to_dr_cof):
"""compute dihedral energy"""
self.dihedral_energy = P.DihedralEnergy(self.dihedral_numbers)(uint_crd, uint_dr_to_dr_cof, self.atom_a,
self.atom_b, self.atom_c, self.atom_d, self.ipn,
self.pk, self.gamc, self.gams, self.pn)
self.sigma_of_dihedral_ene = P.ReduceSum()(self.dihedral_energy)
return self.sigma_of_dihedral_ene
def Dihedral_Force_With_Atom_Energy(self, uint_crd, scaler):
"""compute dihedral force and atom energy"""
self.dfae = P.DihedralForceWithAtomEnergy(dihedral_numbers=self.dihedral_numbers)
self.frc, self.ene = self.dfae(uint_crd, scaler, self.atom_a, self.atom_b, self.atom_c, self.atom_d,
self.ipn, self.pk, self.gamc, self.gams, self.pn)
return self.frc, self.ene

View File

@ -0,0 +1,122 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lennard jones"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class Lennard_Jones_Information(nn.Cell):
"""class Lennard Jones Information"""
def __init__(self, controller):
super(Lennard_Jones_Information, self).__init__()
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.atom_LJ_type = Tensor(np.asarray(self.atom_LJ_type, dtype=np.int32), mstype.int32)
self.LJ_A = Tensor(np.asarray(self.LJ_A, dtype=np.float32), mstype.float32)
self.LJ_B = Tensor(np.asarray(self.LJ_B, dtype=np.float32), mstype.float32)
self.LJ_energy_sum = 0
self.LJ_energy = 0
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
for idx, val in enumerate(context):
if idx < len(context) - 1:
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
start_idx = idx + 2
count = 0
value = list(map(int, context[start_idx].strip().split()))
self.atom_numbers = value[0]
self.atom_type_numbers = value[1]
self.pair_type_numbers = int(self.atom_type_numbers * (self.atom_type_numbers + 1) / 2)
print(self.pair_type_numbers)
break
self.atom_LJ_type = [0] * self.atom_numbers
for idx, val in enumerate(context):
if "%FLAG ATOM_TYPE_INDEX" in val:
count = 0
start_idx = idx
information = []
while count < self.atom_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.atom_numbers):
self.atom_LJ_type[i] = information[i] - 1
break
self.LJ_A = [0] * self.pair_type_numbers
for idx, val in enumerate(context):
if "%FLAG LENNARD_JONES_ACOEF" in val:
count = 0
start_idx = idx
information = []
while count < self.pair_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.pair_type_numbers):
self.LJ_A[i] = 12.0 * information[i]
break
self.LJ_B = [0] * self.pair_type_numbers
for idx, val in enumerate(context):
if "%FLAG LENNARD_JONES_BCOEF" in val:
count = 0
start_idx = idx
information = []
while count < self.pair_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.pair_type_numbers):
self.LJ_B[i] = 6.0 * information[i]
break
def LJ_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof, nl_atom_numbers, nl_atom_serial, cutoff_square):
"""compute LJ energy"""
uint_crd, LJtype, charge = uint_crd_with_LJ
self.LJ_energy = P.LJEnergy(self.atom_numbers, cutoff_square) \
(uint_crd, LJtype, charge, uint_dr_to_dr_cof, nl_atom_numbers, nl_atom_serial, self.LJ_A, self.LJ_B)
self.LJ_energy_sum = P.ReduceSum()(self.LJ_energy)
return self.LJ_energy_sum
def LJ_Force_With_PME_Direct_Force(self, atom_numbers, uint_crd_with_LJ, uint_dr_to_dr_cof, nl_number, nl_serial,
cutoff, beta):
"""compute LJ force with PME direct force"""
assert atom_numbers == self.atom_numbers
assert isinstance(uint_crd_with_LJ, tuple)
uint_crd_f, LJtype, charge = uint_crd_with_LJ
self.ljfd = P.LJForceWithPMEDirectForce(atom_numbers, cutoff, beta)
frc = self.ljfd(uint_crd_f, LJtype, charge, uint_dr_to_dr_cof, nl_number, nl_serial, self.LJ_A, self.LJ_B)
return frc

View File

@ -0,0 +1,212 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""md information"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class md_information(nn.Cell):
"""class md information"""
def __init__(self, controller):
super(md_information, self).__init__()
CONSTANT_TIME_CONVERTION = 20.455
CONSTANT_UINT_MAX_FLOAT = 4294967296.0
self.md_task = controller.md_task
self.mode = 0 if "mode" not in controller.Command_Set else int(controller.Command_Set["mode"])
self.dt = 0.001 * CONSTANT_TIME_CONVERTION if "dt" not in controller.Command_Set \
else float(controller.Command_Set["dt"]) * CONSTANT_TIME_CONVERTION
self.skin = 2.0 if "skin" not in controller.Command_Set \
else float(controller.Command_Set["skin"])
self.trans_vec = [self.skin, self.skin, self.skin]
self.trans_vec_minus = -1 * self.trans_vec
self.step_limit = 1000 if "step_limit" not in controller.Command_Set else int(
controller.Command_Set["step_limit"])
self.netfrc = 0 if "net_force" not in controller.Command_Set else int(controller.Command_Set["net_force"])
self.ntwx = 1000 if "write_information_interval" not in controller.Command_Set else \
int(controller.Command_Set["write_information_interval"])
self.ntce = self.step_limit + 1 if "calculate_energy_interval" not in controller.Command_Set else \
int(controller.Command_Set["calculate_energy_interval"])
self.atom_numbers = 0
self.residue_numbers = 0
self.density = 0.0
self.lin_serial = []
self.h_res_start = []
self.h_res_end = []
self.h_mass = []
self.h_mass_inverse = []
self.h_charge = []
self.steps = 0
if controller.amber_parm is not None:
self.read_basic_system_information_from_amber_file(controller.amber_parm)
if "amber_irest" in controller.Command_Set:
amber_irest = int(controller.Command_Set["amber_irest"])
if controller.initial_coordinates_file is not None:
self.read_basic_system_information_from_rst7(controller.initial_coordinates_file, amber_irest)
self.crd_to_uint_crd_cof = [CONSTANT_UINT_MAX_FLOAT / self.box_length[0],
CONSTANT_UINT_MAX_FLOAT / self.box_length[1],
CONSTANT_UINT_MAX_FLOAT / self.box_length[2]]
self.uint_dr_to_dr_cof = [1.0 / self.crd_to_uint_crd_cof[0], 1.0 / self.crd_to_uint_crd_cof[1],
1.0 / self.crd_to_uint_crd_cof[2]]
self.density *= 1e24 / 6.023e23 / (self.box_length[0] * self.box_length[1] * self.box_length[2])
self.frc = Tensor(np.zeros((self.atom_numbers, 3)), mstype.float32)
self.crd = Tensor(np.array(self.coordinate, dtype=np.float32).reshape((self.atom_numbers, 3)), mstype.float32)
self.crd_n = np.array(self.coordinate).reshape([self.atom_numbers, 3])
self.crd_old = Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32)
self.uint_crd = Tensor(np.zeros([self.atom_numbers, 3], dtype=np.uint32), mstype.uint32)
self.charge = Tensor(self.h_charge, mstype.float32)
self.crd_to_uint_crd_cof_n = np.array(self.crd_to_uint_crd_cof)
self.crd_to_uint_crd_cof = Tensor(self.crd_to_uint_crd_cof, mstype.float32)
self.uint_dr_to_dr_cof = Tensor(self.uint_dr_to_dr_cof, mstype.float32)
self.uint_crd_with_LJ = None
self.d_mass_inverse = Tensor(self.h_mass_inverse, mstype.float32)
self.d_res_start = Tensor(self.h_res_start, mstype.int32)
self.d_res_end = Tensor(self.h_res_end, mstype.int32)
self.d_mass = Tensor(self.h_mass, mstype.float32)
def process1(self, context):
"""process1: read basic system information from amber file"""
for idx, val in enumerate(context):
if idx < len(context) - 1:
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
start_idx = idx + 2
value = list(map(int, context[start_idx].strip().split()))
self.atom_numbers = value[0]
count = len(value) - 1
while count < 10:
start_idx += 1
value = list(map(int, context[start_idx].strip().split()))
count += len(value)
self.residue_numbers = list(map(int, context[start_idx].strip().split()))[10 - (count - 10)]
break
def read_basic_system_information_from_amber_file(self, path):
"""read basic system information from amber file"""
file = open(path, 'r')
context = file.readlines()
file.close()
self.process1(context)
if self.residue_numbers != 0 and self.atom_numbers != 0:
for idx, val in enumerate(context):
if "%FLAG RESIDUE_POINTER" in val:
count = 0
start_idx = idx
while count != self.residue_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
self.lin_serial.extend(value)
count += len(value)
for i in range(self.residue_numbers - 1):
self.h_res_start.append(self.lin_serial[i] - 1)
self.h_res_end.append(self.lin_serial[i + 1] - 1)
self.h_res_start.append(self.lin_serial[-1] - 1)
self.h_res_end.append(self.atom_numbers + 1 - 1)
break
for idx, val in enumerate(context):
if "%FLAG MASS" in val:
count = 0
start_idx = idx
while count != self.atom_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
self.h_mass.extend(value)
count += len(value)
for i in range(self.atom_numbers):
if self.h_mass[i] == 0:
self.h_mass_inverse.append(0.0)
else:
self.h_mass_inverse.append(1.0 / self.h_mass[i])
self.density += self.h_mass[i]
break
for idx, val in enumerate(context):
if "%FLAG CHARGE" in val:
count = 0
start_idx = idx
while count != self.atom_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
self.h_charge.extend(value)
count += len(value)
break
def read_basic_system_information_from_rst7(self, path, irest):
"""read basic system information from rst7"""
file = open(path, 'r')
context = file.readlines()
file.close()
atom_numbers = int(context[1].strip().split()[0])
if atom_numbers != self.atom_numbers:
print("ERROR")
else:
print("check atom_numbers")
information = []
count = 0
start_idx = 1
if irest == 1:
self.simulation_start_time = float(context[1].strip().split()[1])
while count <= 6 * self.atom_numbers + 3:
start_idx += 1
# print(start_idx)
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.coordinate = information[: 3 * self.atom_numbers]
self.velocity = information[3 * self.atom_numbers: 6 * self.atom_numbers]
self.box_length = information[6 * self.atom_numbers:6 * self.atom_numbers + 3]
else:
while count <= 3 * self.atom_numbers + 3:
start_idx += 1
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.coordinate = information[: 3 * self.atom_numbers]
self.velocity = [0.0] * (3 * self.atom_numbers)
self.box_length = information[3 * self.atom_numbers:3 * self.atom_numbers + 3]
self.vel = Tensor(self.velocity, mstype.float32).reshape((self.atom_numbers, 3))
self.acc = Tensor(np.zeros((self.atom_numbers, 3), dtype=np.float32), mstype.float32)
def MD_Information_Crd_To_Uint_Crd(self):
"""transform the crd to uint crd"""
uint_crd = self.crd.asnumpy() * (0.5 * self.crd_to_uint_crd_cof.asnumpy()) * 2
self.uint_crd = Tensor(uint_crd, mstype.uint32)
return self.uint_crd
def Centerize(self):
return
def MD_Information_Temperature(self):
"""compute temperature"""
self.mdtemp = P.MDTemperature(self.residue_numbers, self.atom_numbers)
self.res_ek_energy = self.mdtemp(self.d_res_start, self.d_res_end, self.vel, self.d_mass)
self.d_temperature = P.ReduceSum()(self.res_ek_energy)
return self.d_temperature

View File

@ -0,0 +1,195 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""nb14"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class NON_BOND_14(nn.Cell):
"""class Non bond 14"""
def __init__(self, controller, dihedral, atom_numbers):
super(NON_BOND_14, self).__init__()
self.dihedral_with_hydrogen = dihedral.dihedral_with_hydrogen
self.dihedral_numbers = dihedral.dihedral_numbers
self.dihedral_type_numbers = dihedral.dihedral_type_numbers
self.atom_numbers = atom_numbers
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
self.lj_scale_factor = Tensor(np.asarray(self.h_lj_scale_factor, np.float32), mstype.float32)
self.cf_scale_factor = Tensor(np.asarray(self.h_cf_scale_factor, np.float32), mstype.float32)
def process1(self, context):
"""process1: read information from amberfile"""
for idx, val in enumerate(context):
if "%FLAG SCEE_SCALE_FACTOR" in val:
count = 0
start_idx = idx
information = []
while count < self.dihedral_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.cf_scale_type = information[:self.dihedral_type_numbers]
break
for idx, val in enumerate(context):
if "%FLAG SCNB_SCALE_FACTOR" in val:
count = 0
start_idx = idx
information = []
while count < self.dihedral_type_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(float, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.lj_scale_type = information[:self.dihedral_type_numbers]
break
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.cf_scale_type = [0] * self.dihedral_type_numbers
self.lj_scale_type = [0] * self.dihedral_type_numbers
self.h_atom_a = [0] * self.dihedral_numbers
self.h_atom_b = [0] * self.dihedral_numbers
self.h_lj_scale_factor = [0] * self.dihedral_numbers
self.h_cf_scale_factor = [0] * self.dihedral_numbers
nb14_numbers = 0
for idx, val in enumerate(context):
if "%FLAG DIHEDRALS_INC_HYDROGEN" in val:
count = 0
start_idx = idx
information = []
while count < 5 * self.dihedral_with_hydrogen:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.dihedral_with_hydrogen):
tempa = information[i * 5 + 0]
tempi = information[i * 5 + 1]
tempi2 = information[i * 5 + 2]
tempb = information[i * 5 + 3]
tempi = information[i * 5 + 4]
tempi -= 1
if tempi2 > 0:
self.h_atom_a[nb14_numbers] = tempa / 3
self.h_atom_b[nb14_numbers] = abs(tempb / 3)
self.h_lj_scale_factor[nb14_numbers] = self.lj_scale_type[tempi]
if self.h_lj_scale_factor[nb14_numbers] != 0:
self.h_lj_scale_factor[nb14_numbers] = 1.0 / self.h_lj_scale_factor[nb14_numbers]
self.h_cf_scale_factor[nb14_numbers] = self.cf_scale_type[tempi]
if self.h_cf_scale_factor[nb14_numbers] != 0:
self.h_cf_scale_factor[nb14_numbers] = 1.0 / self.h_cf_scale_factor[nb14_numbers]
nb14_numbers += 1
break
for idx, val in enumerate(context):
if "%FLAG DIHEDRALS_WITHOUT_HYDROGEN" in val:
count = 0
start_idx = idx
information = []
while count < 5 * (self.dihedral_numbers - self.dihedral_with_hydrogen):
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
for i in range(self.dihedral_with_hydrogen, self.dihedral_numbers):
tempa = information[(i - self.dihedral_with_hydrogen) * 5 + 0]
tempi = information[(i - self.dihedral_with_hydrogen) * 5 + 1]
tempi2 = information[(i - self.dihedral_with_hydrogen) * 5 + 2]
tempb = information[(i - self.dihedral_with_hydrogen) * 5 + 3]
tempi = information[(i - self.dihedral_with_hydrogen) * 5 + 4]
tempi -= 1
if tempi2 > 0:
self.h_atom_a[nb14_numbers] = tempa / 3
self.h_atom_b[nb14_numbers] = abs(tempb / 3)
self.h_lj_scale_factor[nb14_numbers] = self.lj_scale_type[tempi]
if self.h_lj_scale_factor[nb14_numbers] != 0:
self.h_lj_scale_factor[nb14_numbers] = 1.0 / self.h_lj_scale_factor[nb14_numbers]
self.h_cf_scale_factor[nb14_numbers] = self.cf_scale_type[tempi]
if self.h_cf_scale_factor[nb14_numbers] != 0:
self.h_cf_scale_factor[nb14_numbers] = 1.0 / self.h_cf_scale_factor[nb14_numbers]
nb14_numbers += 1
break
self.nb14_numbers = nb14_numbers
def Non_Bond_14_LJ_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof, LJ_A, LJ_B):
"""compute Non bond 14 LJ energy"""
assert isinstance(uint_crd_with_LJ, tuple)
uint_crd, LJtype, charge = uint_crd_with_LJ
self.LJ_energy = P.Dihedral14LJEnergy(self.nb14_numbers, self.atom_numbers)(uint_crd, LJtype, charge,
uint_dr_to_dr_cof, self.atom_a,
self.atom_b, self.lj_scale_factor,
LJ_A, LJ_B)
self.nb14_lj_energy_sum = P.ReduceSum()(self.LJ_energy)
return self.nb14_lj_energy_sum
def Non_Bond_14_CF_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof):
"""compute Non bond 14 CF energy"""
assert isinstance(uint_crd_with_LJ, tuple)
uint_crd, LJtype, charge = uint_crd_with_LJ
self.CF_energy = P.Dihedral14CFEnergy(self.nb14_numbers, self.atom_numbers)(uint_crd, LJtype, charge,
uint_dr_to_dr_cof, self.atom_a,
self.atom_b, self.cf_scale_factor)
self.nb14_cf_energy_sum = P.ReduceSum()(self.CF_energy)
return self.nb14_cf_energy_sum
def Non_Bond_14_LJ_CF_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof, LJ_A, LJ_B):
"""compute Non bond 14 LJ and CF energy"""
assert isinstance(uint_crd_with_LJ, tuple)
self.nb14_lj_energy_sum = self.Non_Bond_14_LJ_Energy(uint_crd_with_LJ, uint_dr_to_dr_cof, LJ_A, LJ_B)
self.nb14_cf_energy_sum = self.Non_Bond_14_CF_Energy(uint_crd_with_LJ, uint_dr_to_dr_cof)
return self.nb14_lj_energy_sum, self.nb14_cf_energy_sum
def Non_Bond_14_LJ_CF_Force_With_Atom_Energy(self, uint_crd_with_LJ, boxlength, LJ_A, LJ_B):
"""compute Non bond 14 LJ CF force and atom energy"""
self.d14lj = P.Dihedral14LJCFForceWithAtomEnergy(nb14_numbers=self.nb14_numbers, atom_numbers=self.atom_numbers)
assert isinstance(uint_crd_with_LJ, tuple)
uint_crd_f, LJtype, charge = uint_crd_with_LJ
self.frc, self.atom_ene = self.d14lj(uint_crd_f, LJtype, charge, boxlength, self.atom_a, self.atom_b,
self.lj_scale_factor, self.cf_scale_factor, LJ_A, LJ_B)
return self.frc, self.atom_ene

View File

@ -0,0 +1,207 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""neighbour list"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class nb_infomation(nn.Cell):
"""neighbour list"""
def __init__(self, controller, atom_numbers, box_length):
super(nb_infomation, self).__init__()
self.refresh_interval = 20 if "neighbor_list_refresh_interval" not in controller.Command_Set else \
int(controller.Command_Set["neighbor_list_refresh_interval"])
self.max_atom_in_grid_numbers = 64 if "max_atom_in_grid_numbers" not in controller.Command_Set else \
int(controller.Command_Set["max_atom_in_grid_numbers"])
self.max_neighbor_numbers = 800 if "max_neighbor_numbers" not in controller.Command_Set else \
int(controller.Command_Set["max_neighbor_numbers"])
self.skin = 2.0 if "skin" not in controller.Command_Set else float(controller.Command_Set["skin"])
self.cutoff = 10.0 if "cut" not in controller.Command_Set else float(controller.Command_Set["cut"])
self.cutoff_square = self.cutoff * self.cutoff
self.cutoff_with_skin = self.cutoff + self.skin
self.half_cutoff_with_skin = 0.5 * self.cutoff_with_skin
self.cutoff_with_skin_square = self.cutoff_with_skin * self.cutoff_with_skin
self.half_skin_square = 0.25 * self.skin * self.skin
self.atom_numbers = atom_numbers
self.box_length = box_length
if controller.amber_parm is not None:
file_path = controller.amber_parm
self.read_information_from_amberfile(file_path)
self.Initial_Neighbor_Grid()
self.not_first_time = 0
self.refresh_count = 0
self.atom_numbers_in_grid_bucket = Tensor(np.asarray(self.atom_numbers_in_grid_bucket, np.int32), mstype.int32)
self.bucket = Tensor(
np.asarray(self.bucket, np.int32).reshape([self.grid_numbers, self.max_atom_in_grid_numbers]), mstype.int32)
self.grid_N = Tensor(np.asarray(self.grid_N, np.int32), mstype.int32)
self.grid_length_inverse = Tensor(np.asarray(self.grid_length_inverse, np.float32), mstype.float32)
self.atom_in_grid_serial = Tensor(np.zeros(self.atom_numbers, np.int32), mstype.int32)
self.pointer = Tensor(np.asarray(self.pointer, np.int32).reshape([self.grid_numbers, 125]), mstype.int32)
self.nl_atom_numbers = Tensor(np.zeros(self.atom_numbers, np.int32), mstype.int32)
self.nl_atom_serial = Tensor(np.zeros([self.atom_numbers, self.max_neighbor_numbers], np.int32), mstype.int32)
self.excluded_list_start = Tensor(np.asarray(self.excluded_list_start, np.int32), mstype.int32)
self.excluded_list = Tensor(np.asarray(self.excluded_list, np.int32), mstype.int32)
self.excluded_numbers = Tensor(np.asarray(self.excluded_numbers, np.int32), mstype.int32)
self.need_refresh_flag = Tensor(np.asarray([0], np.int32), mstype.int32)
def read_information_from_amberfile(self, file_path):
"""read information from amberfile"""
file = open(file_path, 'r')
context = file.readlines()
file.close()
self.excluded_list_start = [0] * self.atom_numbers
self.excluded_numbers = [0] * self.atom_numbers
for idx, val in enumerate(context):
if idx < len(context) - 1:
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
start_idx = idx + 2
count = 0
value = list(map(int, context[start_idx].strip().split()))
information = []
information.extend(value)
while count < 11:
start_idx += 1
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
self.excluded_atom_numbers = information[10]
break
for idx, val in enumerate(context):
if "%FLAG NUMBER_EXCLUDED_ATOMS" in val:
count = 0
start_idx = idx
information = []
while count < self.atom_numbers:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
count = 0
for i in range(self.atom_numbers):
self.excluded_numbers[i] = information[i]
self.excluded_list_start[i] = count
count += information[i]
break
total_count = sum(self.excluded_numbers)
self.excluded_list = []
for idx, val in enumerate(context):
if "%FLAG EXCLUDED_ATOMS_LIST" in val:
count = 0
start_idx = idx
information = []
while count < total_count:
start_idx += 1
if "%FORMAT" in context[start_idx]:
continue
else:
value = list(map(int, context[start_idx].strip().split()))
information.extend(value)
count += len(value)
count = 0
for i in range(self.atom_numbers):
tmp_list = []
for _ in range(self.excluded_numbers[i]):
tmp_list.append(information[count] - 1)
count += 1
tmp_list = sorted(tmp_list)
self.excluded_list.extend(tmp_list)
break
def fun(self, Nx, Ny, Nz, l, m, temp_grid_serial, count):
"""fun to replace the for"""
for n in range(-2, 3):
xx = Nx + l
if xx < 0:
xx = xx + self.Nx
elif xx >= self.Nx:
xx = xx - self.Nx
yy = Ny + m
if yy < 0:
yy = yy + self.Ny
elif yy >= self.Ny:
yy = yy - self.Ny
zz = Nz + n
if zz < 0:
zz = zz + self.Nz
elif zz >= self.Nz:
zz = zz - self.Nz
temp_grid_serial[count] = zz * self.Nxy + yy * self.Nx + xx
count += 1
return temp_grid_serial, count
def Initial_Neighbor_Grid(self):
"""initial neighbour grid"""
half_cutoff = self.half_cutoff_with_skin
self.Nx = int(self.box_length[0] / half_cutoff)
self.Ny = int(self.box_length[1] / half_cutoff)
self.Nz = int(self.box_length[2] / half_cutoff)
self.grid_N = [self.Nx, self.Ny, self.Nz]
self.grid_length = [self.box_length[0] / self.Nx, self.box_length[1] / self.Ny, self.box_length[2] / self.Nz]
self.grid_length_inverse = [1.0 / self.grid_length[0], 1.0 / self.grid_length[1], 1.0 / self.grid_length[2]]
self.Nxy = self.Nx * self.Ny
self.grid_numbers = self.Nz * self.Nxy
self.atom_numbers_in_grid_bucket = [0] * self.grid_numbers
self.bucket = [-1] * (self.grid_numbers * self.max_atom_in_grid_numbers)
self.pointer = []
temp_grid_serial = [0] * 125
for i in range(self.grid_numbers):
Nz = int(i / self.Nxy)
Ny = int((i - self.Nxy * Nz) / self.Nx)
Nx = i - self.Nxy * Nz - self.Nx * Ny
count = 0
for l in range(-2, 3):
for m in range(-2, 3):
temp_grid_serial, count = self.fun(Nx, Ny, Nz, l, m, temp_grid_serial, count)
temp_grid_serial = sorted(temp_grid_serial)
self.pointer.extend(temp_grid_serial)
def NeighborListUpdate(self, crd, old_crd, uint_crd, crd_to_uint_crd_cof, uint_dr_to_dr_cof, box_length,
not_first_time=0):
"""NeighborList Update"""
self.not_first_time = not_first_time
self.neighbor_list_update = P.NeighborListUpdate(grid_numbers=self.grid_numbers, atom_numbers=self.atom_numbers,
refresh_count=self.refresh_count,
not_first_time=self.not_first_time,
Nxy=self.Nxy, excluded_atom_numbers=self.excluded_atom_numbers,
cutoff_square=self.cutoff_square,
half_skin_square=self.half_skin_square,
cutoff_with_skin=self.cutoff_with_skin,
half_cutoff_with_skin=self.half_cutoff_with_skin,
cutoff_with_skin_square=self.cutoff_with_skin_square,
refresh_interval=self.refresh_interval, cutoff=self.cutoff,
skin=self.skin,
max_atom_in_grid_numbers=self.max_atom_in_grid_numbers,
max_neighbor_numbers=self.max_neighbor_numbers)
res = self.neighbor_list_update(self.atom_numbers_in_grid_bucket, self.bucket, crd, box_length, self.grid_N,
self.grid_length_inverse, self.atom_in_grid_serial, old_crd,
crd_to_uint_crd_cof, uint_crd, self.pointer, self.nl_atom_numbers,
self.nl_atom_serial, uint_dr_to_dr_cof, self.excluded_list_start,
self.excluded_list, self.excluded_numbers, self.need_refresh_flag)
return res

View File

@ -0,0 +1,146 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PME"""
import math
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P
class Particle_Mesh_Ewald(nn.Cell):
"""class Particle_Mesh_Ewald"""
def __init__(self, controller, md_info):
super(Particle_Mesh_Ewald, self).__init__()
self.cutoff = 10.0 if "cut" not in controller.Command_Set else float(controller.Command_Set["cut"])
self.tolerance = 0.00001 if "PME_Direct_Tolerance" not in controller.Command_Set else float(
controller.Command_Set["PME_Direct_Tolerance"])
self.fftx = -1 if "fftx" not in controller.Command_Set else int(controller.Command_Set["fftx"])
self.ffty = -1 if "ffty" not in controller.Command_Set else int(controller.Command_Set["ffty"])
self.fftz = -1 if "fftz" not in controller.Command_Set else int(controller.Command_Set["fftz"])
self.atom_numbers = md_info.atom_numbers
self.box_length = md_info.box_length
if self.fftx < 0:
self.fftx = self.Get_Fft_Patameter(self.box_length[0])
if self.ffty < 0:
self.ffty = self.Get_Fft_Patameter(self.box_length[1])
if self.fftz < 0:
self.fftz = self.Get_Fft_Patameter(self.box_length[2])
self.beta = self.Get_Beta(self.cutoff, self.tolerance)
self.box_length = Tensor(np.asarray(self.box_length, np.float32), mstype.float32)
print("========== ", self.fftx, self.ffty, self.fftz, self.tolerance, self.beta)
def Get_Beta(self, cutoff, tolerance):
"""get beta"""
high = 1.0
ihigh = 1
while 1:
tempf = math.erfc(high * cutoff) / cutoff
if tempf <= tolerance:
break
high *= 2
ihigh += 1
ihigh += 50
low = 0.0
for _ in range(1, ihigh):
beta = (low + high) / 2
tempf = math.erfc(beta * cutoff) / cutoff
if tempf >= tolerance:
low = beta
else:
high = beta
return beta
def Check_2357_Factor(self, number):
"""check 2357 factor"""
while number > 0:
if number == 1:
return 1
tempn = number / 2
if tempn * 2 != number:
break
number = tempn
while number > 0:
if number == 1:
return 1
tempn = number / 3
if tempn * 3 != number:
break
number = tempn
while number > 0:
if number == 1:
return 1
tempn = number / 5
if tempn * 5 != number:
break
number = tempn
while number > 0:
if number == 1:
return 1
tempn = number / 7
if tempn * 7 != number:
break
number = tempn
return 0
def Get_Fft_Patameter(self, length):
"""get fft parameter"""
tempi = math.ceil(length + 3) >> 2 << 2
if 60 <= tempi <= 68:
tempi = 64
elif 120 <= tempi <= 136:
tempi = 128
elif 240 <= tempi <= 272:
tempi = 256
elif 480 <= tempi <= 544:
tempi = 512
elif 960 <= tempi <= 1088:
tempi = 1024
while 1:
if self.Check_2357_Factor(tempi):
return tempi
tempi += 4
def PME_Energy(self, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof, excluded_list_start,
excluded_list, excluded_numbers):
"""PME_Energy"""
self.pmee = P.PMEEnergy(self.atom_numbers, self.beta, self.fftx, self.ffty, self.fftz)
self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy = \
self.pmee(self.box_length, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof,
excluded_list_start, excluded_list, excluded_numbers)
return self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy
def PME_Excluded_Force(self, uint_crd, scaler, charge, excluded_list_start, excluded_list,
excluded_numbers):
"""PME Excluded Force"""
self.pmeef = P.PMEExcludedForce(atom_numbers=self.atom_numbers, beta=self.beta)
self.frc = self.pmeef(uint_crd, scaler, charge, excluded_list_start, excluded_list, excluded_numbers)
return self.frc
def PME_Reciprocal_Force(self, uint_crd, charge):
"""PME reciprocal force"""
self.pmerf = P.PMEReciprocalForce(self.atom_numbers, self.beta, self.fftx, self.ffty, self.fftz)
self.frc = self.pmerf(self.box_length, uint_crd, charge)
return self.frc
def Energy_Device_To_Host(self):
"""Energy_Device_To_Host"""
self.ee_ene = self.reciprocal_energy + self.self_energy + self.direct_energy + self.correction_energy
return self.ee_ene

View File

@ -0,0 +1,243 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""simulation"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from Langevin_Liujian_md import Langevin_Liujian
from angle import Angle
from bond import Bond
from dihedral import Dihedral
from lennard_jones import Lennard_Jones_Information
from md_information import md_information
from nb14 import NON_BOND_14
from neighbor_list import nb_infomation
from particle_mesh_ewald import Particle_Mesh_Ewald
class controller:
"""class controller"""
def __init__(self, args_opt):
self.input_file = args_opt.i
self.initial_coordinates_file = args_opt.c
self.amber_parm = args_opt.amber_parm
self.restrt = args_opt.r
self.mdcrd = args_opt.x
self.mdout = args_opt.o
self.mdbox = args_opt.box
self.Command_Set = {}
self.md_task = None
self.commands_from_in_file()
def commands_from_in_file(self):
"""commands from in file"""
file = open(self.input_file, 'r')
context = file.readlines()
file.close()
self.md_task = context[0].strip()
for val in context:
if "=" in val:
assert len(val.strip().split("=")) == 2
flag, value = val.strip().split("=")
value = value.replace(",", '')
flag = flag.replace(" ", "")
if flag not in self.Command_Set:
self.Command_Set[flag] = value
else:
print("ERROR COMMAND FILE")
class Simulation(nn.Cell):
"""class simulation"""
def __init__(self, args_opt):
super(Simulation, self).__init__()
self.control = controller(args_opt)
self.md_info = md_information(self.control)
self.bond = Bond(self.control, self.md_info)
self.angle = Angle(self.control)
self.dihedral = Dihedral(self.control)
self.nb14 = NON_BOND_14(self.control, self.dihedral, self.md_info.atom_numbers)
self.nb_info = nb_infomation(self.control, self.md_info.atom_numbers, self.md_info.box_length)
self.LJ_info = Lennard_Jones_Information(self.control)
self.liujian_info = Langevin_Liujian(self.control, self.md_info.atom_numbers)
self.pme_method = Particle_Mesh_Ewald(self.control, self.md_info)
self.box_length = Tensor(np.asarray(self.md_info.box_length, np.float32), mstype.float32)
self.file = None
def Main_Before_Calculate_Force(self):
"""Main Before Calculate Force"""
_ = self.md_info.MD_Information_Crd_To_Uint_Crd()
self.md_info.uint_crd_with_LJ = (self.md_info.uint_crd, self.LJ_info.atom_LJ_type, self.md_info.charge)
return self.md_info.uint_crd, self.md_info.uint_crd_with_LJ
def Initial_Neighbor_List_Update(self, not_first_time):
"""Initial Neighbor List Update"""
res = self.nb_info.NeighborListUpdate(self.md_info.crd, self.md_info.crd_old, self.md_info.uint_crd,
self.md_info.crd_to_uint_crd_cof, self.md_info.uint_dr_to_dr_cof,
self.box_length, not_first_time)
return res
def Main_Calculate_Force(self):
"""main calculate force"""
self.bond.atom_numbers = self.md_info.atom_numbers
md_info = self.md_info
LJ_info = self.LJ_info
nb_info = self.nb_info
pme_method = self.pme_method
bond_frc, _ = self.bond.Bond_Force_With_Atom_Energy(md_info.uint_crd, md_info.uint_dr_to_dr_cof)
frc_t = bond_frc.asnumpy()
angle_frc, _ = self.angle.Angle_Force_With_Atom_Energy(md_info.uint_crd, md_info.uint_dr_to_dr_cof)
frc_t += angle_frc.asnumpy()
dihedral_frc, _ = self.dihedral.Dihedral_Force_With_Atom_Energy(md_info.uint_crd, md_info.uint_dr_to_dr_cof)
frc_t += dihedral_frc.asnumpy()
nb14_frc, _ = self.nb14.Non_Bond_14_LJ_CF_Force_With_Atom_Energy(md_info.uint_crd_with_LJ,
md_info.uint_dr_to_dr_cof, LJ_info.LJ_A,
LJ_info.LJ_B)
frc_t += nb14_frc.asnumpy()
lj_frc = LJ_info.LJ_Force_With_PME_Direct_Force(
md_info.atom_numbers, md_info.uint_crd_with_LJ, md_info.uint_dr_to_dr_cof, nb_info.nl_atom_numbers,
nb_info.nl_atom_serial, nb_info.cutoff, pme_method.beta)
frc_t += lj_frc.asnumpy()
pme_excluded_frc = pme_method.PME_Excluded_Force(
md_info.uint_crd, md_info.uint_dr_to_dr_cof, md_info.charge,
nb_info.excluded_list_start, nb_info.excluded_list,
nb_info.excluded_numbers)
frc_t += pme_excluded_frc.asnumpy()
pme_reciprocal_frc = pme_method.PME_Reciprocal_Force(md_info.uint_crd, md_info.charge)
frc_t += pme_reciprocal_frc.asnumpy()
self.md_info.frc = Tensor(frc_t, mstype.float32)
return self.md_info.frc
def Main_Calculate_Energy(self):
"""main calculate energy"""
_ = self.bond.Bond_Energy(self.md_info.uint_crd, self.md_info.uint_dr_to_dr_cof)
_ = self.angle.Angle_Energy(self.md_info.uint_crd, self.md_info.uint_dr_to_dr_cof)
_ = self.dihedral.Dihedral_Engergy(self.md_info.uint_crd, self.md_info.uint_dr_to_dr_cof)
_ = self.nb14.Non_Bond_14_LJ_CF_Energy(self.md_info.uint_crd_with_LJ, self.md_info.uint_dr_to_dr_cof,
self.LJ_info.LJ_A,
self.LJ_info.LJ_B)
_ = self.LJ_info.LJ_Energy(self.md_info.uint_crd_with_LJ, self.md_info.uint_dr_to_dr_cof,
self.nb_info.nl_atom_numbers, self.nb_info.nl_atom_serial,
self.nb_info.cutoff_square)
_ = self.pme_method.PME_Energy(
self.md_info.uint_crd, self.md_info.charge, self.nb_info.nl_atom_numbers, self.nb_info.nl_atom_serial,
self.md_info.uint_dr_to_dr_cof, self.nb_info.excluded_list_start, self.nb_info.excluded_list,
self.nb_info.excluded_numbers)
_ = self.pme_method.Energy_Device_To_Host()
def Main_After_Calculate_Energy(self):
"""main after calculate energy"""
md_info = self.md_info
LJ_info = self.LJ_info
bond = self.bond
angle = self.angle
dihedral = self.dihedral
nb14 = self.nb14
pme_method = self.pme_method
md_info.total_potential_energy = 0
md_info.total_potential_energy += bond.sigma_of_bond_ene
md_info.total_potential_energy += angle.sigma_of_angle_ene
md_info.total_potential_energy += dihedral.sigma_of_dihedral_ene
md_info.total_potential_energy += nb14.nb14_lj_energy_sum + nb14.nb14_cf_energy_sum
md_info.total_potential_energy += LJ_info.LJ_energy_sum
pme_method.Energy_Device_To_Host()
md_info.total_potential_energy += pme_method.ee_ene
print("md_info.total_potential_energy", md_info.total_potential_energy)
def Main_Iteration_2(self):
"""main iteration2"""
md_info = self.md_info
control = self.control
liujian_info = self.liujian_info
if md_info.mode > 0 and int(control.Command_Set["thermostat"]) == 1:
md_info.vel, md_info.crd, md_info.frc, md_info.acc = liujian_info.MD_Iteration_Leap_Frog(
md_info.d_mass_inverse, md_info.vel, md_info.crd, md_info.frc)
def Main_After_Iteration(self):
"""main after iteration"""
md_info = self.md_info
nb_info = self.nb_info
md_info.Centerize()
_ = nb_info.NeighborListUpdate(md_info.crd, md_info.crd_old, md_info.uint_crd,
md_info.crd_to_uint_crd_cof,
md_info.uint_dr_to_dr_cof, self.box_length, not_first_time=1)
def Main_Print(self):
"""compute the temperature"""
md_info = self.md_info
temperature = md_info.MD_Information_Temperature()
md_info.h_temperature = temperature
steps = md_info.steps
temperature = temperature.asnumpy()
total_potential_energy = md_info.total_potential_energy.asnumpy()
sigma_of_bond_ene = self.bond.sigma_of_bond_ene.asnumpy()
sigma_of_angle_ene = self.angle.sigma_of_angle_ene.asnumpy()
sigma_of_dihedral_ene = self.dihedral.sigma_of_dihedral_ene.asnumpy()
nb14_lj_energy_sum = self.nb14.nb14_lj_energy_sum.asnumpy()
nb14_cf_energy_sum = self.nb14.nb14_cf_energy_sum.asnumpy()
LJ_energy_sum = self.LJ_info.LJ_energy_sum.asnumpy()
ee_ene = self.pme_method.ee_ene.asnumpy()
print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_")
print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)), end=" ")
if self.bond.bond_numbers > 0:
print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ")
if self.angle.angle_numbers > 0:
print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ")
if self.dihedral.dihedral_numbers > 0:
print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ")
if self.nb14.nb14_numbers > 0:
print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ")
print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ")
print("{:>12.3f}".format(float(ee_ene)))
if self.file is not None:
self.file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}"
" {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy),
float(sigma_of_bond_ene), float(sigma_of_angle_ene),
float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum),
float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene)))
return temperature
def Main_Initial(self):
"""main initial"""
if self.control.mdout:
self.file = open(self.control.mdout, 'w')
self.file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n")
def Main_Destroy(self):
"""main destroy"""
if self.file is not None:
self.file.close()
print("Save successfully!")