commit
e789642e9a
|
@ -87,7 +87,7 @@ class UncertaintyEvaluation:
|
|||
self.epi_model = deepcopy(model)
|
||||
self.ale_model = deepcopy(model)
|
||||
self.epi_train_dataset = train_dataset
|
||||
self.ale_train_dataset = train_dataset
|
||||
self.ale_train_dataset = deepcopy(train_dataset)
|
||||
self.task_type = task_type
|
||||
self.epochs = Validator.check_positive_int(epochs)
|
||||
self.epi_uncer_model_path = epi_uncer_model_path
|
||||
|
@ -101,7 +101,8 @@ class UncertaintyEvaluation:
|
|||
if not isinstance(model, Cell):
|
||||
raise TypeError('The model should be Cell type.')
|
||||
if task_type not in ('regression', 'classification'):
|
||||
raise ValueError('The task should be regression or classification.')
|
||||
raise ValueError(
|
||||
'The task should be regression or classification.')
|
||||
if task_type == 'classification':
|
||||
self.num_classes = Validator.check_positive_int(num_classes)
|
||||
else:
|
||||
|
@ -119,15 +120,19 @@ class UncertaintyEvaluation:
|
|||
self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model)
|
||||
if self.epi_uncer_model.drop_count == 0 and self.epi_train_dataset is not None:
|
||||
if self.task_type == 'classification':
|
||||
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
net_loss = SoftmaxCrossEntropyWithLogits(
|
||||
sparse=True, reduction="mean")
|
||||
net_opt = Adam(self.epi_uncer_model.trainable_params())
|
||||
model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
model = Model(self.epi_uncer_model, net_loss,
|
||||
net_opt, metrics={"Accuracy": Accuracy()})
|
||||
else:
|
||||
net_loss = MSELoss()
|
||||
net_opt = Adam(self.epi_uncer_model.trainable_params())
|
||||
model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
|
||||
model = Model(self.epi_uncer_model, net_loss,
|
||||
net_opt, metrics={"MSE": MSE()})
|
||||
if self.save_model:
|
||||
config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs)
|
||||
config_ck = CheckpointConfig(
|
||||
keep_checkpoint_max=self.epochs)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
|
||||
directory=self.epi_uncer_model_path,
|
||||
config=config_ck)
|
||||
|
@ -137,7 +142,8 @@ class UncertaintyEvaluation:
|
|||
model.train(self.epochs, self.epi_train_dataset, dataset_sink_mode=False,
|
||||
callbacks=[LossMonitor()])
|
||||
else:
|
||||
uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
|
||||
uncer_param_dict = load_checkpoint(
|
||||
self.epi_uncer_model_path)
|
||||
load_param_into_net(self.epi_uncer_model, uncer_param_dict)
|
||||
|
||||
def _eval_epistemic_uncertainty(self, eval_data, mc=10):
|
||||
|
@ -164,15 +170,19 @@ class UncertaintyEvaluation:
|
|||
Get the model which can obtain the aleatoric uncertainty.
|
||||
"""
|
||||
if self.ale_train_dataset is None:
|
||||
raise ValueError('The train dataset should not be None when evaluating aleatoric uncertainty.')
|
||||
raise ValueError(
|
||||
'The train dataset should not be None when evaluating aleatoric uncertainty.')
|
||||
if self.ale_uncer_model is None:
|
||||
self.ale_uncer_model = AleatoricUncertaintyModel(self.ale_model, self.num_classes, self.task_type)
|
||||
self.ale_uncer_model = AleatoricUncertaintyModel(
|
||||
self.ale_model, self.num_classes, self.task_type)
|
||||
net_loss = AleatoricLoss(self.task_type)
|
||||
net_opt = Adam(self.ale_uncer_model.trainable_params())
|
||||
if self.task_type == 'classification':
|
||||
model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
model = Model(self.ale_uncer_model, net_loss,
|
||||
net_opt, metrics={"Accuracy": Accuracy()})
|
||||
else:
|
||||
model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
|
||||
model = Model(self.ale_uncer_model, net_loss,
|
||||
net_opt, metrics={"MSE": MSE()})
|
||||
if self.save_model:
|
||||
config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
|
||||
|
@ -284,7 +294,8 @@ class AleatoricUncertaintyModel(Cell):
|
|||
self.ale_model = ale_model
|
||||
self.var_layer = Dense(num_classes, num_classes)
|
||||
else:
|
||||
self.ale_model, self.var_layer, self.pred_layer = self._make_aleatoric(ale_model)
|
||||
self.ale_model, self.var_layer, self.pred_layer = self._make_aleatoric(
|
||||
ale_model)
|
||||
|
||||
def construct(self, x):
|
||||
if self.task == 'classification':
|
||||
|
@ -327,7 +338,8 @@ class AleatoricLoss(Cell):
|
|||
self.exp = P.Exp()
|
||||
self.normal = C.normal
|
||||
self.to_tensor = P.ScalarToArray()
|
||||
self.entropy = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
self.entropy = SoftmaxCrossEntropyWithLogits(
|
||||
sparse=True, reduction="mean")
|
||||
else:
|
||||
self.mean = P.ReduceMean()
|
||||
self.exp = P.Exp()
|
||||
|
@ -337,7 +349,8 @@ class AleatoricLoss(Cell):
|
|||
y_pred, var = data_pred
|
||||
if self.task == 'classification':
|
||||
sample_times = 10
|
||||
epsilon = self.normal((1, sample_times), self.to_tensor(0.0), self.to_tensor(1.0), 0)
|
||||
epsilon = self.normal((1, sample_times), self.to_tensor(
|
||||
0.0), self.to_tensor(1.0), 0)
|
||||
total_loss = 0
|
||||
for i in range(sample_times):
|
||||
y_pred_i = y_pred + epsilon[0][i] * var
|
||||
|
@ -345,5 +358,6 @@ class AleatoricLoss(Cell):
|
|||
total_loss += loss
|
||||
avg_loss = total_loss / sample_times
|
||||
return avg_loss
|
||||
loss = self.mean(0.5 * self.exp(-var) * self.pow(y - y_pred, 2) + 0.5 * var)
|
||||
loss = self.mean(0.5 * self.exp(-var) *
|
||||
self.pow(y - y_pred, 2) + 0.5 * var)
|
||||
return loss
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""main"""
|
||||
|
||||
import time
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from src.simulation_initial import Simulation
|
||||
|
||||
parser = argparse.ArgumentParser(description='Sponge Controller')
|
||||
parser.add_argument('--i', type=str, default=None, help='input file')
|
||||
parser.add_argument('--amber_parm', type=str, default=None,
|
||||
help='paramter file in AMBER type')
|
||||
parser.add_argument('--c', type=str, default=None,
|
||||
help='initial coordinates file')
|
||||
parser.add_argument('--r', type=str, default="restrt", help='')
|
||||
parser.add_argument('--x', type=str, default="mdcrd", help='')
|
||||
parser.add_argument('--o', type=str, default="mdout", help="")
|
||||
parser.add_argument('--box', type=str, default="mdbox", help='')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE,
|
||||
device_target="GPU", device_id=0, save_graphs=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
start = time.time()
|
||||
simulation = Simulation(args_opt)
|
||||
simulation.Main_Initial()
|
||||
res = simulation.Initial_Neighbor_List_Update(not_first_time=0)
|
||||
md_info = simulation.md_info
|
||||
md_info.step_limit = 1
|
||||
for i in range(1, md_info.step_limit + 1):
|
||||
print("steps: ", i)
|
||||
md_info.steps = i
|
||||
simulation.Main_Before_Calculate_Force()
|
||||
simulation.Main_Calculate_Force()
|
||||
simulation.Main_Calculate_Energy()
|
||||
simulation.Main_After_Calculate_Energy()
|
||||
temperature = simulation.Main_Print()
|
||||
simulation.Main_Iteration_2()
|
||||
end = time.time()
|
||||
print("Main time(s):", end - start)
|
||||
simulation.Main_Destroy()
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Langevin Liujian MD class"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
class Langevin_Liujian:
|
||||
"""Langevin_Liujian class"""
|
||||
|
||||
def __init__(self, controller, atom_numbers):
|
||||
self.atom_numbers = atom_numbers
|
||||
if controller.amber_parm is not None:
|
||||
file_path = controller.amber_parm
|
||||
self.read_information_from_amberfile(file_path)
|
||||
|
||||
self.CONSTANT_TIME_CONVERTION = 20.455
|
||||
self.CONSTANT_kB = 0.00198716
|
||||
|
||||
self.target_temperature = 300.0 if "target_temperature" not in controller.Command_Set else float(
|
||||
controller.Command_Set["target_temperature"])
|
||||
self.gamma_ln = 1.0 if "langevin_gamma" not in controller.Command_Set else float(
|
||||
controller.Command_Set["langevin_gamma"])
|
||||
self.rand_seed = 0 if "langevin_seed" not in controller.Command_Set else float(
|
||||
controller.Command_Set["langevin_seed"]) # jiahong0315
|
||||
self.max_velocity = 10000.0 if "velocity_max" not in controller.Command_Set else float(
|
||||
controller.Command_Set["velocity_max"])
|
||||
assert self.max_velocity > 0
|
||||
self.is_max_velocity = 0 if "velocity_max" not in controller.Command_Set else 1
|
||||
print("target temperature is ", self.target_temperature)
|
||||
print("friction coefficient is ", self.gamma_ln, "ps^-1")
|
||||
print("random seed is ", self.rand_seed)
|
||||
self.dt = float(controller.Command_Set["dt"])
|
||||
self.dt *= self.CONSTANT_TIME_CONVERTION
|
||||
self.half_dt = 0.5 * self.dt
|
||||
self.float4_numbers = math.ceil(3.0 * self.atom_numbers / 4.0)
|
||||
self.gamma_ln = self.gamma_ln / self.CONSTANT_TIME_CONVERTION
|
||||
self.exp_gamma = math.exp(-1 * self.gamma_ln * self.dt)
|
||||
self.sqrt_gamma = math.sqrt((1. - self.exp_gamma * self.exp_gamma) * self.target_temperature * self.CONSTANT_kB)
|
||||
self.h_sqrt_mass = [0] * self.atom_numbers
|
||||
for i in range(self.atom_numbers):
|
||||
self.h_sqrt_mass[i] = self.sqrt_gamma * math.sqrt(1. / self.h_mass[i])
|
||||
self.d_sqrt_mass = Tensor(self.h_sqrt_mass, mstype.float32)
|
||||
|
||||
def read_information_from_amberfile(self, file_path):
|
||||
"""read information from amberfile"""
|
||||
file = open(file_path, 'r')
|
||||
context = file.readlines()
|
||||
file.close()
|
||||
self.h_mass = [0] * self.atom_numbers
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG MASS" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.atom_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
|
||||
for i in range(self.atom_numbers):
|
||||
self.h_mass[i] = information[i]
|
||||
break
|
||||
|
||||
def MDIterationLeapFrog_Liujian(self, atom_numbers, half_dt, dt, exp_gamma, inverse_mass, sqrt_mass_inverse, vel,
|
||||
crd, frc, random_frc):
|
||||
"""compute MDIterationLeapFrog Liujian"""
|
||||
inverse_mass = inverse_mass.reshape((-1, 1))
|
||||
sqrt_mass_inverse = sqrt_mass_inverse.reshape((-1, 1))
|
||||
acc = inverse_mass * frc
|
||||
vel = vel + dt * acc
|
||||
crd = crd + half_dt * vel
|
||||
vel = exp_gamma * vel + sqrt_mass_inverse * random_frc
|
||||
crd = crd + half_dt * vel
|
||||
frc = Tensor(np.zeros((atom_numbers, 3)), mstype.float32)
|
||||
return vel, crd, frc, acc
|
||||
|
||||
def MD_Iteration_Leap_Frog(self, d_mass_inverse, vel_in, crd_in, frc_in):
|
||||
"""MD_Iteration_Leap_Frog"""
|
||||
np.random.seed(int(self.rand_seed))
|
||||
self.rand_force = Tensor(np.zeros((self.atom_numbers, 3)), mstype.float32)
|
||||
# self.rand_force = Tensor(np.random.randn(self.atom_numbers, 3), mstype.float32)
|
||||
vel, crd, frc, acc = self.MDIterationLeapFrog_Liujian(atom_numbers=self.atom_numbers, half_dt=self.half_dt,
|
||||
dt=self.dt, exp_gamma=self.exp_gamma,
|
||||
inverse_mass=d_mass_inverse,
|
||||
sqrt_mass_inverse=self.d_sqrt_mass,
|
||||
vel=vel_in, crd=crd_in,
|
||||
frc=frc_in, random_frc=self.rand_force)
|
||||
return vel, crd, frc, acc
|
|
@ -0,0 +1,175 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""angle class"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Angle(nn.Cell):
|
||||
"""Angle class"""
|
||||
|
||||
def __init__(self, controller):
|
||||
super(Angle, self).__init__()
|
||||
if controller.amber_parm is not None:
|
||||
file_path = controller.amber_parm
|
||||
self.read_information_from_amberfile(file_path)
|
||||
|
||||
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
|
||||
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
|
||||
self.atom_c = Tensor(np.asarray(self.h_atom_c, np.int32), mstype.int32)
|
||||
self.angle_k = Tensor(np.asarray(self.h_angle_k, np.float32), mstype.float32)
|
||||
self.angle_theta0 = Tensor(np.asarray(self.h_angle_theta0, np.float32), mstype.float32)
|
||||
|
||||
def read_process1(self, context):
|
||||
"""read_information_from_amberfile process1"""
|
||||
for idx, val in enumerate(context):
|
||||
if idx < len(context) - 1:
|
||||
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
|
||||
start_idx = idx + 2
|
||||
count = 0
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
self.angle_with_H_numbers = value[4]
|
||||
self.angle_without_H_numbers = value[5]
|
||||
self.angle_numbers = self.angle_with_H_numbers + self.angle_without_H_numbers
|
||||
information = []
|
||||
information.extend(value)
|
||||
while count < 15:
|
||||
start_idx += 1
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.angle_type_numbers = information[16]
|
||||
print("angle type numbers ", self.angle_type_numbers)
|
||||
break
|
||||
|
||||
def read_process2(self, context):
|
||||
"""read_information_from_amberfile process2"""
|
||||
angle_count = 0
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG ANGLES_INC_HYDROGEN" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < 4 * self.angle_with_H_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
for _ in range(self.angle_with_H_numbers):
|
||||
self.h_atom_a[angle_count] = information[angle_count * 4 + 0] / 3
|
||||
self.h_atom_b[angle_count] = information[angle_count * 4 + 1] / 3
|
||||
self.h_atom_c[angle_count] = information[angle_count * 4 + 2] / 3
|
||||
self.h_type[angle_count] = information[angle_count * 4 + 3] - 1
|
||||
angle_count += 1
|
||||
|
||||
break
|
||||
return angle_count
|
||||
|
||||
def read_information_from_amberfile(self, file_path):
|
||||
"""read information from amberfile"""
|
||||
file = open(file_path, 'r')
|
||||
context = file.readlines()
|
||||
file.close()
|
||||
self.read_process1(context)
|
||||
|
||||
self.h_atom_a = [0] * self.angle_numbers
|
||||
self.h_atom_b = [0] * self.angle_numbers
|
||||
self.h_atom_c = [0] * self.angle_numbers
|
||||
self.h_type = [0] * self.angle_numbers
|
||||
angle_count = self.read_process2(context)
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG ANGLES_WITHOUT_HYDROGEN" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < 4 * self.angle_without_H_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
for i in range(self.angle_without_H_numbers):
|
||||
self.h_atom_a[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 0] / 3
|
||||
self.h_atom_b[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 1] / 3
|
||||
self.h_atom_c[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 2] / 3
|
||||
self.h_type[angle_count] = information[(angle_count - self.angle_with_H_numbers) * 4 + 3] - 1
|
||||
angle_count += 1
|
||||
break
|
||||
|
||||
self.type_k = [0] * self.angle_type_numbers
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG ANGLE_FORCE_CONSTANT" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.angle_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
# print(start_idx)
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.type_k = information[:self.angle_type_numbers]
|
||||
break
|
||||
|
||||
self.type_theta0 = [0] * self.angle_type_numbers
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG ANGLE_EQUIL_VALUE" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.angle_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.type_theta0 = information[:self.angle_type_numbers]
|
||||
break
|
||||
if self.angle_numbers != angle_count:
|
||||
print("angle count %d != angle_number %d ", angle_count, self.angle_numbers)
|
||||
|
||||
self.h_angle_k = []
|
||||
self.h_angle_theta0 = []
|
||||
for i in range(self.angle_numbers):
|
||||
self.h_angle_k.append(self.type_k[self.h_type[i]])
|
||||
self.h_angle_theta0.append(self.type_theta0[self.h_type[i]])
|
||||
|
||||
def Angle_Energy(self, uint_crd, uint_dr_to_dr_cof):
|
||||
"""compute angle energy"""
|
||||
self.angle_energy = P.AngleEnergy(self.angle_numbers)(uint_crd, uint_dr_to_dr_cof, self.atom_a, self.atom_b,
|
||||
self.atom_c, self.angle_k, self.angle_theta0)
|
||||
self.sigma_of_angle_ene = P.ReduceSum()(self.angle_energy)
|
||||
return self.sigma_of_angle_ene
|
||||
|
||||
def Angle_Force_With_Atom_Energy(self, uint_crd, scaler):
|
||||
"""compute angle force with atom energy"""
|
||||
print("angele angle numbers:", self.angle_numbers)
|
||||
self.afae = P.AngleForceWithAtomEnergy(angle_numbers=self.angle_numbers)
|
||||
frc, ene = self.afae(uint_crd, scaler, self.atom_a, self.atom_b, self.atom_c, self.angle_k, self.angle_theta0)
|
||||
return frc, ene
|
|
@ -0,0 +1,163 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""bond class"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Bond(nn.Cell):
|
||||
"""bond class"""
|
||||
|
||||
def __init__(self, controller, md_info):
|
||||
super(Bond, self).__init__()
|
||||
|
||||
self.atom_numbers = md_info.atom_numbers
|
||||
|
||||
if controller.amber_parm is not None:
|
||||
file_path = controller.amber_parm
|
||||
self.read_information_from_amberfile(file_path)
|
||||
|
||||
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
|
||||
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
|
||||
self.bond_k = Tensor(np.asarray(self.h_k, np.float32), mstype.float32)
|
||||
self.bond_r0 = Tensor(np.asarray(self.h_r0, np.float32), mstype.float32)
|
||||
|
||||
def process1(self, context):
|
||||
"""process1: read information from amberfile"""
|
||||
for idx, val in enumerate(context):
|
||||
if idx < len(context) - 1:
|
||||
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
|
||||
start_idx = idx + 2
|
||||
count = 0
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
self.bond_with_hydrogen = value[2]
|
||||
self.bond_numbers = value[3]
|
||||
self.bond_numbers += self.bond_with_hydrogen
|
||||
print(self.bond_numbers)
|
||||
information = []
|
||||
information.extend(value)
|
||||
while count < 16:
|
||||
start_idx += 1
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.bond_type_numbers = information[15]
|
||||
print("bond type numbers ", self.bond_type_numbers)
|
||||
break
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG BOND_FORCE_CONSTANT" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.bond_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.bond_type_k = information[:self.bond_type_numbers]
|
||||
break
|
||||
|
||||
def read_information_from_amberfile(self, file_path):
|
||||
"""read information from amberfile"""
|
||||
file = open(file_path, 'r')
|
||||
context = file.readlines()
|
||||
file.close()
|
||||
self.process1(context)
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG BOND_EQUIL_VALUE" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.bond_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.bond_type_r = information[:self.bond_type_numbers]
|
||||
break
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG BONDS_INC_HYDROGEN" in val:
|
||||
self.h_atom_a = [0] * self.bond_numbers
|
||||
self.h_atom_b = [0] * self.bond_numbers
|
||||
self.h_k = [0] * self.bond_numbers
|
||||
self.h_r0 = [0] * self.bond_numbers
|
||||
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < 3 * self.bond_with_hydrogen:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
|
||||
for i in range(self.bond_with_hydrogen):
|
||||
self.h_atom_a[i] = information[3 * i + 0] / 3
|
||||
self.h_atom_b[i] = information[3 * i + 1] / 3
|
||||
tmpi = information[3 * i + 2] - 1
|
||||
self.h_k[i] = self.bond_type_k[tmpi]
|
||||
self.h_r0[i] = self.bond_type_r[tmpi]
|
||||
break
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG BONDS_WITHOUT_HYDROGEN" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < 3 * (self.bond_numbers - self.bond_with_hydrogen):
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
|
||||
for i in range(self.bond_with_hydrogen, self.bond_numbers):
|
||||
self.h_atom_a[i] = information[3 * (i - self.bond_with_hydrogen) + 0] / 3
|
||||
self.h_atom_b[i] = information[3 * (i - self.bond_with_hydrogen) + 1] / 3
|
||||
tmpi = information[3 * (i - self.bond_with_hydrogen) + 2] - 1
|
||||
self.h_k[i] = self.bond_type_k[tmpi]
|
||||
self.h_r0[i] = self.bond_type_r[tmpi]
|
||||
break
|
||||
|
||||
def Bond_Energy(self, uint_crd, uint_dr_to_dr_cof):
|
||||
"""compute bond energy"""
|
||||
self.bond_energy = P.BondEnergy(self.bond_numbers, self.atom_numbers)(uint_crd, uint_dr_to_dr_cof, self.atom_a,
|
||||
self.atom_b, self.bond_k, self.bond_r0)
|
||||
self.sigma_of_bond_ene = P.ReduceSum()(self.bond_energy)
|
||||
return self.sigma_of_bond_ene
|
||||
|
||||
def Bond_Force_With_Atom_Energy(self, uint_crd, scaler):
|
||||
"""compute bond force with atom energy"""
|
||||
self.bfatomenergy = P.BondForceWithAtomEnergy(bond_numbers=self.bond_numbers,
|
||||
atom_numbers=self.atom_numbers)
|
||||
frc, atom_energy = self.bfatomenergy(uint_crd, scaler, self.atom_a, self.atom_b, self.bond_k, self.bond_r0)
|
||||
return frc, atom_energy
|
|
@ -0,0 +1,221 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dihedral class"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Dihedral(nn.Cell):
|
||||
"""dihedral class"""
|
||||
|
||||
def __init__(self, controller):
|
||||
super(Dihedral, self).__init__()
|
||||
self.CONSTANT_Pi = 3.1415926535897932
|
||||
if controller.amber_parm is not None:
|
||||
file_path = controller.amber_parm
|
||||
self.read_information_from_amberfile(file_path)
|
||||
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
|
||||
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
|
||||
self.atom_c = Tensor(np.asarray(self.h_atom_c, np.int32), mstype.int32)
|
||||
self.atom_d = Tensor(np.asarray(self.h_atom_d, np.int32), mstype.int32)
|
||||
|
||||
self.pk = Tensor(np.asarray(self.pk, np.float32), mstype.float32)
|
||||
self.gamc = Tensor(np.asarray(self.gamc, np.float32), mstype.float32)
|
||||
self.gams = Tensor(np.asarray(self.gams, np.float32), mstype.float32)
|
||||
self.pn = Tensor(np.asarray(self.pn, np.float32), mstype.float32)
|
||||
self.ipn = Tensor(np.asarray(self.ipn, np.int32), mstype.int32)
|
||||
|
||||
def process1(self, context):
|
||||
"""process1: read information from amberfile"""
|
||||
for idx, val in enumerate(context):
|
||||
if idx < len(context) - 1:
|
||||
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
|
||||
start_idx = idx + 2
|
||||
count = 0
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
self.dihedral_with_hydrogen = value[6]
|
||||
self.dihedral_numbers = value[7]
|
||||
self.dihedral_numbers += self.dihedral_with_hydrogen
|
||||
information = []
|
||||
information.extend(value)
|
||||
while count < 15:
|
||||
start_idx += 1
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.dihedral_type_numbers = information[17]
|
||||
print("dihedral type numbers ", self.dihedral_type_numbers)
|
||||
break
|
||||
|
||||
self.phase_type = [0] * self.dihedral_type_numbers
|
||||
self.pk_type = [0] * self.dihedral_type_numbers
|
||||
self.pn_type = [0] * self.dihedral_type_numbers
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG DIHEDRAL_FORCE_CONSTANT" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.dihedral_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.pk_type = information[:self.dihedral_type_numbers]
|
||||
break
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG DIHEDRAL_PHASE" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.dihedral_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.phase_type = information[:self.dihedral_type_numbers]
|
||||
break
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG DIHEDRAL_PERIODICITY" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.dihedral_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.pn_type = information[:self.dihedral_type_numbers]
|
||||
break
|
||||
|
||||
def read_information_from_amberfile(self, file_path):
|
||||
"""read information from amberfile"""
|
||||
file = open(file_path, 'r')
|
||||
context = file.readlines()
|
||||
file.close()
|
||||
|
||||
self.process1(context)
|
||||
|
||||
self.h_atom_a = [0] * self.dihedral_numbers
|
||||
self.h_atom_b = [0] * self.dihedral_numbers
|
||||
self.h_atom_c = [0] * self.dihedral_numbers
|
||||
self.h_atom_d = [0] * self.dihedral_numbers
|
||||
self.pk = []
|
||||
self.gamc = []
|
||||
self.gams = []
|
||||
self.pn = []
|
||||
self.ipn = []
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG DIHEDRALS_INC_HYDROGEN" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < 5 * self.dihedral_with_hydrogen:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
for i in range(self.dihedral_with_hydrogen):
|
||||
self.h_atom_a[i] = information[i * 5 + 0] / 3
|
||||
self.h_atom_b[i] = information[i * 5 + 1] / 3
|
||||
self.h_atom_c[i] = information[i * 5 + 2] / 3
|
||||
self.h_atom_d[i] = abs(information[i * 5 + 3] / 3)
|
||||
tmpi = information[i * 5 + 4] - 1
|
||||
self.pk.append(self.pk_type[tmpi])
|
||||
tmpf = self.phase_type[tmpi]
|
||||
if abs(tmpf - self.CONSTANT_Pi) <= 0.001:
|
||||
tmpf = self.CONSTANT_Pi
|
||||
tmpf2 = math.cos(tmpf)
|
||||
if abs(tmpf2) < 1e-6:
|
||||
tmpf2 = 0
|
||||
self.gamc.append(tmpf2 * self.pk[i])
|
||||
tmpf2 = math.sin(tmpf)
|
||||
if abs(tmpf2) < 1e-6:
|
||||
tmpf2 = 0
|
||||
self.gams.append(tmpf2 * self.pk[i])
|
||||
self.pn.append(abs(self.pn_type[tmpi]))
|
||||
self.ipn.append(int(self.pn[i] + 0.001))
|
||||
break
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG DIHEDRALS_WITHOUT_HYDROGEN" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < 5 * (self.dihedral_numbers - self.dihedral_with_hydrogen):
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
for i in range(self.dihedral_with_hydrogen, self.dihedral_numbers):
|
||||
self.h_atom_a[i] = information[(i - self.dihedral_with_hydrogen) * 5 + 0] / 3
|
||||
self.h_atom_b[i] = information[(i - self.dihedral_with_hydrogen) * 5 + 1] / 3
|
||||
self.h_atom_c[i] = information[(i - self.dihedral_with_hydrogen) * 5 + 2] / 3
|
||||
self.h_atom_d[i] = abs(information[(i - self.dihedral_with_hydrogen) * 5 + 3] / 3)
|
||||
tmpi = information[(i - self.dihedral_with_hydrogen) * 5 + 4] - 1
|
||||
self.pk.append(self.pk_type[tmpi])
|
||||
tmpf = self.phase_type[tmpi]
|
||||
if abs(tmpf - self.CONSTANT_Pi) <= 0.001:
|
||||
tmpf = self.CONSTANT_Pi
|
||||
tmpf2 = math.cos(tmpf)
|
||||
if abs(tmpf2) < 1e-6:
|
||||
tmpf2 = 0
|
||||
self.gamc.append(tmpf2 * self.pk[i])
|
||||
tmpf2 = math.sin(tmpf)
|
||||
if abs(tmpf2) < 1e-6:
|
||||
tmpf2 = 0
|
||||
self.gams.append(tmpf2 * self.pk[i])
|
||||
self.pn.append(abs(self.pn_type[tmpi]))
|
||||
self.ipn.append(int(self.pn[i] + 0.001))
|
||||
break
|
||||
for i in range(self.dihedral_numbers):
|
||||
if self.h_atom_c[i] < 0:
|
||||
self.h_atom_c[i] *= -1
|
||||
|
||||
def Dihedral_Engergy(self, uint_crd, uint_dr_to_dr_cof):
|
||||
"""compute dihedral energy"""
|
||||
self.dihedral_energy = P.DihedralEnergy(self.dihedral_numbers)(uint_crd, uint_dr_to_dr_cof, self.atom_a,
|
||||
self.atom_b, self.atom_c, self.atom_d, self.ipn,
|
||||
self.pk, self.gamc, self.gams, self.pn)
|
||||
self.sigma_of_dihedral_ene = P.ReduceSum()(self.dihedral_energy)
|
||||
return self.sigma_of_dihedral_ene
|
||||
|
||||
def Dihedral_Force_With_Atom_Energy(self, uint_crd, scaler):
|
||||
"""compute dihedral force and atom energy"""
|
||||
self.dfae = P.DihedralForceWithAtomEnergy(dihedral_numbers=self.dihedral_numbers)
|
||||
self.frc, self.ene = self.dfae(uint_crd, scaler, self.atom_a, self.atom_b, self.atom_c, self.atom_d,
|
||||
self.ipn, self.pk, self.gamc, self.gams, self.pn)
|
||||
return self.frc, self.ene
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lennard jones"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Lennard_Jones_Information(nn.Cell):
|
||||
"""class Lennard Jones Information"""
|
||||
|
||||
def __init__(self, controller):
|
||||
super(Lennard_Jones_Information, self).__init__()
|
||||
if controller.amber_parm is not None:
|
||||
file_path = controller.amber_parm
|
||||
self.read_information_from_amberfile(file_path)
|
||||
self.atom_LJ_type = Tensor(np.asarray(self.atom_LJ_type, dtype=np.int32), mstype.int32)
|
||||
self.LJ_A = Tensor(np.asarray(self.LJ_A, dtype=np.float32), mstype.float32)
|
||||
self.LJ_B = Tensor(np.asarray(self.LJ_B, dtype=np.float32), mstype.float32)
|
||||
self.LJ_energy_sum = 0
|
||||
self.LJ_energy = 0
|
||||
|
||||
def read_information_from_amberfile(self, file_path):
|
||||
"""read information from amberfile"""
|
||||
file = open(file_path, 'r')
|
||||
context = file.readlines()
|
||||
file.close()
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if idx < len(context) - 1:
|
||||
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
|
||||
start_idx = idx + 2
|
||||
count = 0
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
self.atom_numbers = value[0]
|
||||
self.atom_type_numbers = value[1]
|
||||
self.pair_type_numbers = int(self.atom_type_numbers * (self.atom_type_numbers + 1) / 2)
|
||||
print(self.pair_type_numbers)
|
||||
break
|
||||
self.atom_LJ_type = [0] * self.atom_numbers
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG ATOM_TYPE_INDEX" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.atom_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
for i in range(self.atom_numbers):
|
||||
self.atom_LJ_type[i] = information[i] - 1
|
||||
break
|
||||
self.LJ_A = [0] * self.pair_type_numbers
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG LENNARD_JONES_ACOEF" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.pair_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
for i in range(self.pair_type_numbers):
|
||||
self.LJ_A[i] = 12.0 * information[i]
|
||||
break
|
||||
self.LJ_B = [0] * self.pair_type_numbers
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG LENNARD_JONES_BCOEF" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.pair_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
for i in range(self.pair_type_numbers):
|
||||
self.LJ_B[i] = 6.0 * information[i]
|
||||
break
|
||||
|
||||
def LJ_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof, nl_atom_numbers, nl_atom_serial, cutoff_square):
|
||||
"""compute LJ energy"""
|
||||
uint_crd, LJtype, charge = uint_crd_with_LJ
|
||||
self.LJ_energy = P.LJEnergy(self.atom_numbers, cutoff_square) \
|
||||
(uint_crd, LJtype, charge, uint_dr_to_dr_cof, nl_atom_numbers, nl_atom_serial, self.LJ_A, self.LJ_B)
|
||||
self.LJ_energy_sum = P.ReduceSum()(self.LJ_energy)
|
||||
return self.LJ_energy_sum
|
||||
|
||||
def LJ_Force_With_PME_Direct_Force(self, atom_numbers, uint_crd_with_LJ, uint_dr_to_dr_cof, nl_number, nl_serial,
|
||||
cutoff, beta):
|
||||
"""compute LJ force with PME direct force"""
|
||||
assert atom_numbers == self.atom_numbers
|
||||
assert isinstance(uint_crd_with_LJ, tuple)
|
||||
uint_crd_f, LJtype, charge = uint_crd_with_LJ
|
||||
self.ljfd = P.LJForceWithPMEDirectForce(atom_numbers, cutoff, beta)
|
||||
frc = self.ljfd(uint_crd_f, LJtype, charge, uint_dr_to_dr_cof, nl_number, nl_serial, self.LJ_A, self.LJ_B)
|
||||
return frc
|
|
@ -0,0 +1,212 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""md information"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class md_information(nn.Cell):
|
||||
"""class md information"""
|
||||
|
||||
def __init__(self, controller):
|
||||
super(md_information, self).__init__()
|
||||
CONSTANT_TIME_CONVERTION = 20.455
|
||||
CONSTANT_UINT_MAX_FLOAT = 4294967296.0
|
||||
self.md_task = controller.md_task
|
||||
self.mode = 0 if "mode" not in controller.Command_Set else int(controller.Command_Set["mode"])
|
||||
self.dt = 0.001 * CONSTANT_TIME_CONVERTION if "dt" not in controller.Command_Set \
|
||||
else float(controller.Command_Set["dt"]) * CONSTANT_TIME_CONVERTION
|
||||
self.skin = 2.0 if "skin" not in controller.Command_Set \
|
||||
else float(controller.Command_Set["skin"])
|
||||
self.trans_vec = [self.skin, self.skin, self.skin]
|
||||
self.trans_vec_minus = -1 * self.trans_vec
|
||||
self.step_limit = 1000 if "step_limit" not in controller.Command_Set else int(
|
||||
controller.Command_Set["step_limit"])
|
||||
self.netfrc = 0 if "net_force" not in controller.Command_Set else int(controller.Command_Set["net_force"])
|
||||
self.ntwx = 1000 if "write_information_interval" not in controller.Command_Set else \
|
||||
int(controller.Command_Set["write_information_interval"])
|
||||
self.ntce = self.step_limit + 1 if "calculate_energy_interval" not in controller.Command_Set else \
|
||||
int(controller.Command_Set["calculate_energy_interval"])
|
||||
self.atom_numbers = 0
|
||||
self.residue_numbers = 0
|
||||
self.density = 0.0
|
||||
self.lin_serial = []
|
||||
self.h_res_start = []
|
||||
self.h_res_end = []
|
||||
self.h_mass = []
|
||||
self.h_mass_inverse = []
|
||||
self.h_charge = []
|
||||
self.steps = 0
|
||||
|
||||
if controller.amber_parm is not None:
|
||||
self.read_basic_system_information_from_amber_file(controller.amber_parm)
|
||||
|
||||
if "amber_irest" in controller.Command_Set:
|
||||
amber_irest = int(controller.Command_Set["amber_irest"])
|
||||
if controller.initial_coordinates_file is not None:
|
||||
self.read_basic_system_information_from_rst7(controller.initial_coordinates_file, amber_irest)
|
||||
|
||||
self.crd_to_uint_crd_cof = [CONSTANT_UINT_MAX_FLOAT / self.box_length[0],
|
||||
CONSTANT_UINT_MAX_FLOAT / self.box_length[1],
|
||||
CONSTANT_UINT_MAX_FLOAT / self.box_length[2]]
|
||||
self.uint_dr_to_dr_cof = [1.0 / self.crd_to_uint_crd_cof[0], 1.0 / self.crd_to_uint_crd_cof[1],
|
||||
1.0 / self.crd_to_uint_crd_cof[2]]
|
||||
self.density *= 1e24 / 6.023e23 / (self.box_length[0] * self.box_length[1] * self.box_length[2])
|
||||
self.frc = Tensor(np.zeros((self.atom_numbers, 3)), mstype.float32)
|
||||
self.crd = Tensor(np.array(self.coordinate, dtype=np.float32).reshape((self.atom_numbers, 3)), mstype.float32)
|
||||
self.crd_n = np.array(self.coordinate).reshape([self.atom_numbers, 3])
|
||||
self.crd_old = Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32)
|
||||
self.uint_crd = Tensor(np.zeros([self.atom_numbers, 3], dtype=np.uint32), mstype.uint32)
|
||||
self.charge = Tensor(self.h_charge, mstype.float32)
|
||||
self.crd_to_uint_crd_cof_n = np.array(self.crd_to_uint_crd_cof)
|
||||
self.crd_to_uint_crd_cof = Tensor(self.crd_to_uint_crd_cof, mstype.float32)
|
||||
self.uint_dr_to_dr_cof = Tensor(self.uint_dr_to_dr_cof, mstype.float32)
|
||||
self.uint_crd_with_LJ = None
|
||||
self.d_mass_inverse = Tensor(self.h_mass_inverse, mstype.float32)
|
||||
self.d_res_start = Tensor(self.h_res_start, mstype.int32)
|
||||
self.d_res_end = Tensor(self.h_res_end, mstype.int32)
|
||||
self.d_mass = Tensor(self.h_mass, mstype.float32)
|
||||
|
||||
def process1(self, context):
|
||||
"""process1: read basic system information from amber file"""
|
||||
for idx, val in enumerate(context):
|
||||
if idx < len(context) - 1:
|
||||
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
|
||||
start_idx = idx + 2
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
self.atom_numbers = value[0]
|
||||
count = len(value) - 1
|
||||
while count < 10:
|
||||
start_idx += 1
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
count += len(value)
|
||||
self.residue_numbers = list(map(int, context[start_idx].strip().split()))[10 - (count - 10)]
|
||||
break
|
||||
|
||||
def read_basic_system_information_from_amber_file(self, path):
|
||||
"""read basic system information from amber file"""
|
||||
file = open(path, 'r')
|
||||
context = file.readlines()
|
||||
file.close()
|
||||
self.process1(context)
|
||||
|
||||
if self.residue_numbers != 0 and self.atom_numbers != 0:
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG RESIDUE_POINTER" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
while count != self.residue_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
self.lin_serial.extend(value)
|
||||
count += len(value)
|
||||
for i in range(self.residue_numbers - 1):
|
||||
self.h_res_start.append(self.lin_serial[i] - 1)
|
||||
self.h_res_end.append(self.lin_serial[i + 1] - 1)
|
||||
self.h_res_start.append(self.lin_serial[-1] - 1)
|
||||
self.h_res_end.append(self.atom_numbers + 1 - 1)
|
||||
break
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG MASS" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
while count != self.atom_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
self.h_mass.extend(value)
|
||||
count += len(value)
|
||||
for i in range(self.atom_numbers):
|
||||
if self.h_mass[i] == 0:
|
||||
self.h_mass_inverse.append(0.0)
|
||||
else:
|
||||
self.h_mass_inverse.append(1.0 / self.h_mass[i])
|
||||
self.density += self.h_mass[i]
|
||||
break
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG CHARGE" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
while count != self.atom_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
self.h_charge.extend(value)
|
||||
count += len(value)
|
||||
break
|
||||
|
||||
def read_basic_system_information_from_rst7(self, path, irest):
|
||||
"""read basic system information from rst7"""
|
||||
file = open(path, 'r')
|
||||
context = file.readlines()
|
||||
file.close()
|
||||
atom_numbers = int(context[1].strip().split()[0])
|
||||
if atom_numbers != self.atom_numbers:
|
||||
print("ERROR")
|
||||
else:
|
||||
print("check atom_numbers")
|
||||
information = []
|
||||
count = 0
|
||||
start_idx = 1
|
||||
if irest == 1:
|
||||
self.simulation_start_time = float(context[1].strip().split()[1])
|
||||
while count <= 6 * self.atom_numbers + 3:
|
||||
start_idx += 1
|
||||
# print(start_idx)
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.coordinate = information[: 3 * self.atom_numbers]
|
||||
self.velocity = information[3 * self.atom_numbers: 6 * self.atom_numbers]
|
||||
self.box_length = information[6 * self.atom_numbers:6 * self.atom_numbers + 3]
|
||||
else:
|
||||
while count <= 3 * self.atom_numbers + 3:
|
||||
start_idx += 1
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.coordinate = information[: 3 * self.atom_numbers]
|
||||
self.velocity = [0.0] * (3 * self.atom_numbers)
|
||||
self.box_length = information[3 * self.atom_numbers:3 * self.atom_numbers + 3]
|
||||
self.vel = Tensor(self.velocity, mstype.float32).reshape((self.atom_numbers, 3))
|
||||
self.acc = Tensor(np.zeros((self.atom_numbers, 3), dtype=np.float32), mstype.float32)
|
||||
|
||||
def MD_Information_Crd_To_Uint_Crd(self):
|
||||
"""transform the crd to uint crd"""
|
||||
uint_crd = self.crd.asnumpy() * (0.5 * self.crd_to_uint_crd_cof.asnumpy()) * 2
|
||||
self.uint_crd = Tensor(uint_crd, mstype.uint32)
|
||||
|
||||
return self.uint_crd
|
||||
|
||||
def Centerize(self):
|
||||
return
|
||||
|
||||
def MD_Information_Temperature(self):
|
||||
"""compute temperature"""
|
||||
self.mdtemp = P.MDTemperature(self.residue_numbers, self.atom_numbers)
|
||||
self.res_ek_energy = self.mdtemp(self.d_res_start, self.d_res_end, self.vel, self.d_mass)
|
||||
self.d_temperature = P.ReduceSum()(self.res_ek_energy)
|
||||
return self.d_temperature
|
|
@ -0,0 +1,195 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""nb14"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class NON_BOND_14(nn.Cell):
|
||||
"""class Non bond 14"""
|
||||
|
||||
def __init__(self, controller, dihedral, atom_numbers):
|
||||
super(NON_BOND_14, self).__init__()
|
||||
|
||||
self.dihedral_with_hydrogen = dihedral.dihedral_with_hydrogen
|
||||
self.dihedral_numbers = dihedral.dihedral_numbers
|
||||
self.dihedral_type_numbers = dihedral.dihedral_type_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
|
||||
if controller.amber_parm is not None:
|
||||
file_path = controller.amber_parm
|
||||
self.read_information_from_amberfile(file_path)
|
||||
|
||||
self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
|
||||
self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
|
||||
self.lj_scale_factor = Tensor(np.asarray(self.h_lj_scale_factor, np.float32), mstype.float32)
|
||||
self.cf_scale_factor = Tensor(np.asarray(self.h_cf_scale_factor, np.float32), mstype.float32)
|
||||
|
||||
def process1(self, context):
|
||||
"""process1: read information from amberfile"""
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG SCEE_SCALE_FACTOR" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.dihedral_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.cf_scale_type = information[:self.dihedral_type_numbers]
|
||||
break
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG SCNB_SCALE_FACTOR" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.dihedral_type_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(float, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.lj_scale_type = information[:self.dihedral_type_numbers]
|
||||
break
|
||||
|
||||
def read_information_from_amberfile(self, file_path):
|
||||
"""read information from amberfile"""
|
||||
file = open(file_path, 'r')
|
||||
context = file.readlines()
|
||||
file.close()
|
||||
|
||||
self.cf_scale_type = [0] * self.dihedral_type_numbers
|
||||
self.lj_scale_type = [0] * self.dihedral_type_numbers
|
||||
|
||||
self.h_atom_a = [0] * self.dihedral_numbers
|
||||
self.h_atom_b = [0] * self.dihedral_numbers
|
||||
self.h_lj_scale_factor = [0] * self.dihedral_numbers
|
||||
self.h_cf_scale_factor = [0] * self.dihedral_numbers
|
||||
nb14_numbers = 0
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG DIHEDRALS_INC_HYDROGEN" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < 5 * self.dihedral_with_hydrogen:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
for i in range(self.dihedral_with_hydrogen):
|
||||
tempa = information[i * 5 + 0]
|
||||
tempi = information[i * 5 + 1]
|
||||
tempi2 = information[i * 5 + 2]
|
||||
tempb = information[i * 5 + 3]
|
||||
tempi = information[i * 5 + 4]
|
||||
|
||||
tempi -= 1
|
||||
if tempi2 > 0:
|
||||
self.h_atom_a[nb14_numbers] = tempa / 3
|
||||
self.h_atom_b[nb14_numbers] = abs(tempb / 3)
|
||||
self.h_lj_scale_factor[nb14_numbers] = self.lj_scale_type[tempi]
|
||||
if self.h_lj_scale_factor[nb14_numbers] != 0:
|
||||
self.h_lj_scale_factor[nb14_numbers] = 1.0 / self.h_lj_scale_factor[nb14_numbers]
|
||||
self.h_cf_scale_factor[nb14_numbers] = self.cf_scale_type[tempi]
|
||||
if self.h_cf_scale_factor[nb14_numbers] != 0:
|
||||
self.h_cf_scale_factor[nb14_numbers] = 1.0 / self.h_cf_scale_factor[nb14_numbers]
|
||||
nb14_numbers += 1
|
||||
break
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG DIHEDRALS_WITHOUT_HYDROGEN" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < 5 * (self.dihedral_numbers - self.dihedral_with_hydrogen):
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
for i in range(self.dihedral_with_hydrogen, self.dihedral_numbers):
|
||||
tempa = information[(i - self.dihedral_with_hydrogen) * 5 + 0]
|
||||
tempi = information[(i - self.dihedral_with_hydrogen) * 5 + 1]
|
||||
tempi2 = information[(i - self.dihedral_with_hydrogen) * 5 + 2]
|
||||
tempb = information[(i - self.dihedral_with_hydrogen) * 5 + 3]
|
||||
tempi = information[(i - self.dihedral_with_hydrogen) * 5 + 4]
|
||||
|
||||
tempi -= 1
|
||||
if tempi2 > 0:
|
||||
self.h_atom_a[nb14_numbers] = tempa / 3
|
||||
self.h_atom_b[nb14_numbers] = abs(tempb / 3)
|
||||
self.h_lj_scale_factor[nb14_numbers] = self.lj_scale_type[tempi]
|
||||
if self.h_lj_scale_factor[nb14_numbers] != 0:
|
||||
self.h_lj_scale_factor[nb14_numbers] = 1.0 / self.h_lj_scale_factor[nb14_numbers]
|
||||
self.h_cf_scale_factor[nb14_numbers] = self.cf_scale_type[tempi]
|
||||
if self.h_cf_scale_factor[nb14_numbers] != 0:
|
||||
self.h_cf_scale_factor[nb14_numbers] = 1.0 / self.h_cf_scale_factor[nb14_numbers]
|
||||
nb14_numbers += 1
|
||||
break
|
||||
|
||||
self.nb14_numbers = nb14_numbers
|
||||
|
||||
def Non_Bond_14_LJ_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof, LJ_A, LJ_B):
|
||||
"""compute Non bond 14 LJ energy"""
|
||||
assert isinstance(uint_crd_with_LJ, tuple)
|
||||
uint_crd, LJtype, charge = uint_crd_with_LJ
|
||||
self.LJ_energy = P.Dihedral14LJEnergy(self.nb14_numbers, self.atom_numbers)(uint_crd, LJtype, charge,
|
||||
uint_dr_to_dr_cof, self.atom_a,
|
||||
self.atom_b, self.lj_scale_factor,
|
||||
LJ_A, LJ_B)
|
||||
self.nb14_lj_energy_sum = P.ReduceSum()(self.LJ_energy)
|
||||
|
||||
return self.nb14_lj_energy_sum
|
||||
|
||||
def Non_Bond_14_CF_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof):
|
||||
"""compute Non bond 14 CF energy"""
|
||||
assert isinstance(uint_crd_with_LJ, tuple)
|
||||
uint_crd, LJtype, charge = uint_crd_with_LJ
|
||||
self.CF_energy = P.Dihedral14CFEnergy(self.nb14_numbers, self.atom_numbers)(uint_crd, LJtype, charge,
|
||||
uint_dr_to_dr_cof, self.atom_a,
|
||||
self.atom_b, self.cf_scale_factor)
|
||||
self.nb14_cf_energy_sum = P.ReduceSum()(self.CF_energy)
|
||||
return self.nb14_cf_energy_sum
|
||||
|
||||
def Non_Bond_14_LJ_CF_Energy(self, uint_crd_with_LJ, uint_dr_to_dr_cof, LJ_A, LJ_B):
|
||||
"""compute Non bond 14 LJ and CF energy"""
|
||||
assert isinstance(uint_crd_with_LJ, tuple)
|
||||
self.nb14_lj_energy_sum = self.Non_Bond_14_LJ_Energy(uint_crd_with_LJ, uint_dr_to_dr_cof, LJ_A, LJ_B)
|
||||
self.nb14_cf_energy_sum = self.Non_Bond_14_CF_Energy(uint_crd_with_LJ, uint_dr_to_dr_cof)
|
||||
|
||||
return self.nb14_lj_energy_sum, self.nb14_cf_energy_sum
|
||||
|
||||
def Non_Bond_14_LJ_CF_Force_With_Atom_Energy(self, uint_crd_with_LJ, boxlength, LJ_A, LJ_B):
|
||||
"""compute Non bond 14 LJ CF force and atom energy"""
|
||||
self.d14lj = P.Dihedral14LJCFForceWithAtomEnergy(nb14_numbers=self.nb14_numbers, atom_numbers=self.atom_numbers)
|
||||
assert isinstance(uint_crd_with_LJ, tuple)
|
||||
uint_crd_f, LJtype, charge = uint_crd_with_LJ
|
||||
self.frc, self.atom_ene = self.d14lj(uint_crd_f, LJtype, charge, boxlength, self.atom_a, self.atom_b,
|
||||
self.lj_scale_factor, self.cf_scale_factor, LJ_A, LJ_B)
|
||||
return self.frc, self.atom_ene
|
|
@ -0,0 +1,207 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""neighbour list"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class nb_infomation(nn.Cell):
|
||||
"""neighbour list"""
|
||||
|
||||
def __init__(self, controller, atom_numbers, box_length):
|
||||
super(nb_infomation, self).__init__()
|
||||
self.refresh_interval = 20 if "neighbor_list_refresh_interval" not in controller.Command_Set else \
|
||||
int(controller.Command_Set["neighbor_list_refresh_interval"])
|
||||
self.max_atom_in_grid_numbers = 64 if "max_atom_in_grid_numbers" not in controller.Command_Set else \
|
||||
int(controller.Command_Set["max_atom_in_grid_numbers"])
|
||||
self.max_neighbor_numbers = 800 if "max_neighbor_numbers" not in controller.Command_Set else \
|
||||
int(controller.Command_Set["max_neighbor_numbers"])
|
||||
self.skin = 2.0 if "skin" not in controller.Command_Set else float(controller.Command_Set["skin"])
|
||||
self.cutoff = 10.0 if "cut" not in controller.Command_Set else float(controller.Command_Set["cut"])
|
||||
self.cutoff_square = self.cutoff * self.cutoff
|
||||
self.cutoff_with_skin = self.cutoff + self.skin
|
||||
self.half_cutoff_with_skin = 0.5 * self.cutoff_with_skin
|
||||
self.cutoff_with_skin_square = self.cutoff_with_skin * self.cutoff_with_skin
|
||||
self.half_skin_square = 0.25 * self.skin * self.skin
|
||||
self.atom_numbers = atom_numbers
|
||||
self.box_length = box_length
|
||||
|
||||
if controller.amber_parm is not None:
|
||||
file_path = controller.amber_parm
|
||||
self.read_information_from_amberfile(file_path)
|
||||
|
||||
self.Initial_Neighbor_Grid()
|
||||
self.not_first_time = 0
|
||||
self.refresh_count = 0
|
||||
|
||||
self.atom_numbers_in_grid_bucket = Tensor(np.asarray(self.atom_numbers_in_grid_bucket, np.int32), mstype.int32)
|
||||
self.bucket = Tensor(
|
||||
np.asarray(self.bucket, np.int32).reshape([self.grid_numbers, self.max_atom_in_grid_numbers]), mstype.int32)
|
||||
self.grid_N = Tensor(np.asarray(self.grid_N, np.int32), mstype.int32)
|
||||
self.grid_length_inverse = Tensor(np.asarray(self.grid_length_inverse, np.float32), mstype.float32)
|
||||
self.atom_in_grid_serial = Tensor(np.zeros(self.atom_numbers, np.int32), mstype.int32)
|
||||
self.pointer = Tensor(np.asarray(self.pointer, np.int32).reshape([self.grid_numbers, 125]), mstype.int32)
|
||||
self.nl_atom_numbers = Tensor(np.zeros(self.atom_numbers, np.int32), mstype.int32)
|
||||
self.nl_atom_serial = Tensor(np.zeros([self.atom_numbers, self.max_neighbor_numbers], np.int32), mstype.int32)
|
||||
self.excluded_list_start = Tensor(np.asarray(self.excluded_list_start, np.int32), mstype.int32)
|
||||
self.excluded_list = Tensor(np.asarray(self.excluded_list, np.int32), mstype.int32)
|
||||
self.excluded_numbers = Tensor(np.asarray(self.excluded_numbers, np.int32), mstype.int32)
|
||||
self.need_refresh_flag = Tensor(np.asarray([0], np.int32), mstype.int32)
|
||||
|
||||
def read_information_from_amberfile(self, file_path):
|
||||
"""read information from amberfile"""
|
||||
file = open(file_path, 'r')
|
||||
context = file.readlines()
|
||||
file.close()
|
||||
self.excluded_list_start = [0] * self.atom_numbers
|
||||
self.excluded_numbers = [0] * self.atom_numbers
|
||||
|
||||
for idx, val in enumerate(context):
|
||||
if idx < len(context) - 1:
|
||||
if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
|
||||
start_idx = idx + 2
|
||||
count = 0
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information = []
|
||||
information.extend(value)
|
||||
while count < 11:
|
||||
start_idx += 1
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
self.excluded_atom_numbers = information[10]
|
||||
break
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG NUMBER_EXCLUDED_ATOMS" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < self.atom_numbers:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
count = 0
|
||||
for i in range(self.atom_numbers):
|
||||
self.excluded_numbers[i] = information[i]
|
||||
self.excluded_list_start[i] = count
|
||||
count += information[i]
|
||||
break
|
||||
|
||||
total_count = sum(self.excluded_numbers)
|
||||
self.excluded_list = []
|
||||
for idx, val in enumerate(context):
|
||||
if "%FLAG EXCLUDED_ATOMS_LIST" in val:
|
||||
count = 0
|
||||
start_idx = idx
|
||||
information = []
|
||||
while count < total_count:
|
||||
start_idx += 1
|
||||
if "%FORMAT" in context[start_idx]:
|
||||
continue
|
||||
else:
|
||||
value = list(map(int, context[start_idx].strip().split()))
|
||||
information.extend(value)
|
||||
count += len(value)
|
||||
|
||||
count = 0
|
||||
for i in range(self.atom_numbers):
|
||||
tmp_list = []
|
||||
for _ in range(self.excluded_numbers[i]):
|
||||
tmp_list.append(information[count] - 1)
|
||||
count += 1
|
||||
tmp_list = sorted(tmp_list)
|
||||
self.excluded_list.extend(tmp_list)
|
||||
break
|
||||
|
||||
def fun(self, Nx, Ny, Nz, l, m, temp_grid_serial, count):
|
||||
"""fun to replace the for"""
|
||||
for n in range(-2, 3):
|
||||
xx = Nx + l
|
||||
if xx < 0:
|
||||
xx = xx + self.Nx
|
||||
elif xx >= self.Nx:
|
||||
xx = xx - self.Nx
|
||||
yy = Ny + m
|
||||
if yy < 0:
|
||||
yy = yy + self.Ny
|
||||
elif yy >= self.Ny:
|
||||
yy = yy - self.Ny
|
||||
zz = Nz + n
|
||||
if zz < 0:
|
||||
zz = zz + self.Nz
|
||||
elif zz >= self.Nz:
|
||||
zz = zz - self.Nz
|
||||
temp_grid_serial[count] = zz * self.Nxy + yy * self.Nx + xx
|
||||
count += 1
|
||||
return temp_grid_serial, count
|
||||
|
||||
def Initial_Neighbor_Grid(self):
|
||||
"""initial neighbour grid"""
|
||||
half_cutoff = self.half_cutoff_with_skin
|
||||
self.Nx = int(self.box_length[0] / half_cutoff)
|
||||
self.Ny = int(self.box_length[1] / half_cutoff)
|
||||
self.Nz = int(self.box_length[2] / half_cutoff)
|
||||
self.grid_N = [self.Nx, self.Ny, self.Nz]
|
||||
self.grid_length = [self.box_length[0] / self.Nx, self.box_length[1] / self.Ny, self.box_length[2] / self.Nz]
|
||||
self.grid_length_inverse = [1.0 / self.grid_length[0], 1.0 / self.grid_length[1], 1.0 / self.grid_length[2]]
|
||||
self.Nxy = self.Nx * self.Ny
|
||||
self.grid_numbers = self.Nz * self.Nxy
|
||||
|
||||
self.atom_numbers_in_grid_bucket = [0] * self.grid_numbers
|
||||
self.bucket = [-1] * (self.grid_numbers * self.max_atom_in_grid_numbers)
|
||||
self.pointer = []
|
||||
temp_grid_serial = [0] * 125
|
||||
for i in range(self.grid_numbers):
|
||||
Nz = int(i / self.Nxy)
|
||||
Ny = int((i - self.Nxy * Nz) / self.Nx)
|
||||
Nx = i - self.Nxy * Nz - self.Nx * Ny
|
||||
count = 0
|
||||
for l in range(-2, 3):
|
||||
for m in range(-2, 3):
|
||||
temp_grid_serial, count = self.fun(Nx, Ny, Nz, l, m, temp_grid_serial, count)
|
||||
temp_grid_serial = sorted(temp_grid_serial)
|
||||
self.pointer.extend(temp_grid_serial)
|
||||
|
||||
def NeighborListUpdate(self, crd, old_crd, uint_crd, crd_to_uint_crd_cof, uint_dr_to_dr_cof, box_length,
|
||||
not_first_time=0):
|
||||
"""NeighborList Update"""
|
||||
self.not_first_time = not_first_time
|
||||
self.neighbor_list_update = P.NeighborListUpdate(grid_numbers=self.grid_numbers, atom_numbers=self.atom_numbers,
|
||||
refresh_count=self.refresh_count,
|
||||
not_first_time=self.not_first_time,
|
||||
Nxy=self.Nxy, excluded_atom_numbers=self.excluded_atom_numbers,
|
||||
cutoff_square=self.cutoff_square,
|
||||
half_skin_square=self.half_skin_square,
|
||||
cutoff_with_skin=self.cutoff_with_skin,
|
||||
half_cutoff_with_skin=self.half_cutoff_with_skin,
|
||||
cutoff_with_skin_square=self.cutoff_with_skin_square,
|
||||
refresh_interval=self.refresh_interval, cutoff=self.cutoff,
|
||||
skin=self.skin,
|
||||
max_atom_in_grid_numbers=self.max_atom_in_grid_numbers,
|
||||
max_neighbor_numbers=self.max_neighbor_numbers)
|
||||
|
||||
res = self.neighbor_list_update(self.atom_numbers_in_grid_bucket, self.bucket, crd, box_length, self.grid_N,
|
||||
self.grid_length_inverse, self.atom_in_grid_serial, old_crd,
|
||||
crd_to_uint_crd_cof, uint_crd, self.pointer, self.nl_atom_numbers,
|
||||
self.nl_atom_serial, uint_dr_to_dr_cof, self.excluded_list_start,
|
||||
self.excluded_list, self.excluded_numbers, self.need_refresh_flag)
|
||||
return res
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""PME"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Particle_Mesh_Ewald(nn.Cell):
|
||||
"""class Particle_Mesh_Ewald"""
|
||||
|
||||
def __init__(self, controller, md_info):
|
||||
super(Particle_Mesh_Ewald, self).__init__()
|
||||
self.cutoff = 10.0 if "cut" not in controller.Command_Set else float(controller.Command_Set["cut"])
|
||||
self.tolerance = 0.00001 if "PME_Direct_Tolerance" not in controller.Command_Set else float(
|
||||
controller.Command_Set["PME_Direct_Tolerance"])
|
||||
self.fftx = -1 if "fftx" not in controller.Command_Set else int(controller.Command_Set["fftx"])
|
||||
self.ffty = -1 if "ffty" not in controller.Command_Set else int(controller.Command_Set["ffty"])
|
||||
self.fftz = -1 if "fftz" not in controller.Command_Set else int(controller.Command_Set["fftz"])
|
||||
self.atom_numbers = md_info.atom_numbers
|
||||
self.box_length = md_info.box_length
|
||||
|
||||
if self.fftx < 0:
|
||||
self.fftx = self.Get_Fft_Patameter(self.box_length[0])
|
||||
if self.ffty < 0:
|
||||
self.ffty = self.Get_Fft_Patameter(self.box_length[1])
|
||||
if self.fftz < 0:
|
||||
self.fftz = self.Get_Fft_Patameter(self.box_length[2])
|
||||
|
||||
self.beta = self.Get_Beta(self.cutoff, self.tolerance)
|
||||
self.box_length = Tensor(np.asarray(self.box_length, np.float32), mstype.float32)
|
||||
|
||||
print("========== ", self.fftx, self.ffty, self.fftz, self.tolerance, self.beta)
|
||||
|
||||
def Get_Beta(self, cutoff, tolerance):
|
||||
"""get beta"""
|
||||
high = 1.0
|
||||
ihigh = 1
|
||||
while 1:
|
||||
tempf = math.erfc(high * cutoff) / cutoff
|
||||
if tempf <= tolerance:
|
||||
break
|
||||
high *= 2
|
||||
ihigh += 1
|
||||
ihigh += 50
|
||||
low = 0.0
|
||||
for _ in range(1, ihigh):
|
||||
beta = (low + high) / 2
|
||||
tempf = math.erfc(beta * cutoff) / cutoff
|
||||
if tempf >= tolerance:
|
||||
low = beta
|
||||
else:
|
||||
high = beta
|
||||
return beta
|
||||
|
||||
def Check_2357_Factor(self, number):
|
||||
"""check 2357 factor"""
|
||||
while number > 0:
|
||||
if number == 1:
|
||||
return 1
|
||||
tempn = number / 2
|
||||
if tempn * 2 != number:
|
||||
break
|
||||
number = tempn
|
||||
while number > 0:
|
||||
if number == 1:
|
||||
return 1
|
||||
tempn = number / 3
|
||||
if tempn * 3 != number:
|
||||
break
|
||||
number = tempn
|
||||
while number > 0:
|
||||
if number == 1:
|
||||
return 1
|
||||
tempn = number / 5
|
||||
if tempn * 5 != number:
|
||||
break
|
||||
number = tempn
|
||||
while number > 0:
|
||||
if number == 1:
|
||||
return 1
|
||||
tempn = number / 7
|
||||
if tempn * 7 != number:
|
||||
break
|
||||
number = tempn
|
||||
return 0
|
||||
|
||||
def Get_Fft_Patameter(self, length):
|
||||
"""get fft parameter"""
|
||||
tempi = math.ceil(length + 3) >> 2 << 2
|
||||
if 60 <= tempi <= 68:
|
||||
tempi = 64
|
||||
elif 120 <= tempi <= 136:
|
||||
tempi = 128
|
||||
elif 240 <= tempi <= 272:
|
||||
tempi = 256
|
||||
elif 480 <= tempi <= 544:
|
||||
tempi = 512
|
||||
elif 960 <= tempi <= 1088:
|
||||
tempi = 1024
|
||||
while 1:
|
||||
if self.Check_2357_Factor(tempi):
|
||||
return tempi
|
||||
tempi += 4
|
||||
|
||||
def PME_Energy(self, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof, excluded_list_start,
|
||||
excluded_list, excluded_numbers):
|
||||
"""PME_Energy"""
|
||||
self.pmee = P.PMEEnergy(self.atom_numbers, self.beta, self.fftx, self.ffty, self.fftz)
|
||||
self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy = \
|
||||
self.pmee(self.box_length, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof,
|
||||
excluded_list_start, excluded_list, excluded_numbers)
|
||||
return self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy
|
||||
|
||||
def PME_Excluded_Force(self, uint_crd, scaler, charge, excluded_list_start, excluded_list,
|
||||
excluded_numbers):
|
||||
"""PME Excluded Force"""
|
||||
self.pmeef = P.PMEExcludedForce(atom_numbers=self.atom_numbers, beta=self.beta)
|
||||
self.frc = self.pmeef(uint_crd, scaler, charge, excluded_list_start, excluded_list, excluded_numbers)
|
||||
return self.frc
|
||||
|
||||
def PME_Reciprocal_Force(self, uint_crd, charge):
|
||||
"""PME reciprocal force"""
|
||||
self.pmerf = P.PMEReciprocalForce(self.atom_numbers, self.beta, self.fftx, self.ffty, self.fftz)
|
||||
self.frc = self.pmerf(self.box_length, uint_crd, charge)
|
||||
return self.frc
|
||||
|
||||
def Energy_Device_To_Host(self):
|
||||
"""Energy_Device_To_Host"""
|
||||
self.ee_ene = self.reciprocal_energy + self.self_energy + self.direct_energy + self.correction_energy
|
||||
return self.ee_ene
|
|
@ -0,0 +1,243 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""simulation"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, nn
|
||||
|
||||
from Langevin_Liujian_md import Langevin_Liujian
|
||||
from angle import Angle
|
||||
from bond import Bond
|
||||
from dihedral import Dihedral
|
||||
from lennard_jones import Lennard_Jones_Information
|
||||
from md_information import md_information
|
||||
from nb14 import NON_BOND_14
|
||||
from neighbor_list import nb_infomation
|
||||
from particle_mesh_ewald import Particle_Mesh_Ewald
|
||||
|
||||
|
||||
class controller:
|
||||
"""class controller"""
|
||||
def __init__(self, args_opt):
|
||||
self.input_file = args_opt.i
|
||||
self.initial_coordinates_file = args_opt.c
|
||||
self.amber_parm = args_opt.amber_parm
|
||||
self.restrt = args_opt.r
|
||||
self.mdcrd = args_opt.x
|
||||
self.mdout = args_opt.o
|
||||
self.mdbox = args_opt.box
|
||||
|
||||
self.Command_Set = {}
|
||||
self.md_task = None
|
||||
self.commands_from_in_file()
|
||||
|
||||
def commands_from_in_file(self):
|
||||
"""commands from in file"""
|
||||
file = open(self.input_file, 'r')
|
||||
context = file.readlines()
|
||||
file.close()
|
||||
self.md_task = context[0].strip()
|
||||
for val in context:
|
||||
if "=" in val:
|
||||
assert len(val.strip().split("=")) == 2
|
||||
flag, value = val.strip().split("=")
|
||||
value = value.replace(",", '')
|
||||
flag = flag.replace(" ", "")
|
||||
if flag not in self.Command_Set:
|
||||
self.Command_Set[flag] = value
|
||||
else:
|
||||
print("ERROR COMMAND FILE")
|
||||
|
||||
|
||||
class Simulation(nn.Cell):
|
||||
"""class simulation"""
|
||||
|
||||
def __init__(self, args_opt):
|
||||
super(Simulation, self).__init__()
|
||||
self.control = controller(args_opt)
|
||||
self.md_info = md_information(self.control)
|
||||
self.bond = Bond(self.control, self.md_info)
|
||||
self.angle = Angle(self.control)
|
||||
self.dihedral = Dihedral(self.control)
|
||||
self.nb14 = NON_BOND_14(self.control, self.dihedral, self.md_info.atom_numbers)
|
||||
self.nb_info = nb_infomation(self.control, self.md_info.atom_numbers, self.md_info.box_length)
|
||||
self.LJ_info = Lennard_Jones_Information(self.control)
|
||||
self.liujian_info = Langevin_Liujian(self.control, self.md_info.atom_numbers)
|
||||
self.pme_method = Particle_Mesh_Ewald(self.control, self.md_info)
|
||||
self.box_length = Tensor(np.asarray(self.md_info.box_length, np.float32), mstype.float32)
|
||||
self.file = None
|
||||
|
||||
def Main_Before_Calculate_Force(self):
|
||||
"""Main Before Calculate Force"""
|
||||
_ = self.md_info.MD_Information_Crd_To_Uint_Crd()
|
||||
self.md_info.uint_crd_with_LJ = (self.md_info.uint_crd, self.LJ_info.atom_LJ_type, self.md_info.charge)
|
||||
return self.md_info.uint_crd, self.md_info.uint_crd_with_LJ
|
||||
|
||||
def Initial_Neighbor_List_Update(self, not_first_time):
|
||||
"""Initial Neighbor List Update"""
|
||||
res = self.nb_info.NeighborListUpdate(self.md_info.crd, self.md_info.crd_old, self.md_info.uint_crd,
|
||||
self.md_info.crd_to_uint_crd_cof, self.md_info.uint_dr_to_dr_cof,
|
||||
self.box_length, not_first_time)
|
||||
|
||||
return res
|
||||
|
||||
def Main_Calculate_Force(self):
|
||||
"""main calculate force"""
|
||||
self.bond.atom_numbers = self.md_info.atom_numbers
|
||||
md_info = self.md_info
|
||||
LJ_info = self.LJ_info
|
||||
nb_info = self.nb_info
|
||||
pme_method = self.pme_method
|
||||
bond_frc, _ = self.bond.Bond_Force_With_Atom_Energy(md_info.uint_crd, md_info.uint_dr_to_dr_cof)
|
||||
frc_t = bond_frc.asnumpy()
|
||||
|
||||
angle_frc, _ = self.angle.Angle_Force_With_Atom_Energy(md_info.uint_crd, md_info.uint_dr_to_dr_cof)
|
||||
frc_t += angle_frc.asnumpy()
|
||||
|
||||
dihedral_frc, _ = self.dihedral.Dihedral_Force_With_Atom_Energy(md_info.uint_crd, md_info.uint_dr_to_dr_cof)
|
||||
frc_t += dihedral_frc.asnumpy()
|
||||
|
||||
nb14_frc, _ = self.nb14.Non_Bond_14_LJ_CF_Force_With_Atom_Energy(md_info.uint_crd_with_LJ,
|
||||
md_info.uint_dr_to_dr_cof, LJ_info.LJ_A,
|
||||
LJ_info.LJ_B)
|
||||
frc_t += nb14_frc.asnumpy()
|
||||
|
||||
lj_frc = LJ_info.LJ_Force_With_PME_Direct_Force(
|
||||
md_info.atom_numbers, md_info.uint_crd_with_LJ, md_info.uint_dr_to_dr_cof, nb_info.nl_atom_numbers,
|
||||
nb_info.nl_atom_serial, nb_info.cutoff, pme_method.beta)
|
||||
frc_t += lj_frc.asnumpy()
|
||||
|
||||
pme_excluded_frc = pme_method.PME_Excluded_Force(
|
||||
md_info.uint_crd, md_info.uint_dr_to_dr_cof, md_info.charge,
|
||||
nb_info.excluded_list_start, nb_info.excluded_list,
|
||||
nb_info.excluded_numbers)
|
||||
frc_t += pme_excluded_frc.asnumpy()
|
||||
|
||||
pme_reciprocal_frc = pme_method.PME_Reciprocal_Force(md_info.uint_crd, md_info.charge)
|
||||
frc_t += pme_reciprocal_frc.asnumpy()
|
||||
|
||||
self.md_info.frc = Tensor(frc_t, mstype.float32)
|
||||
|
||||
return self.md_info.frc
|
||||
|
||||
def Main_Calculate_Energy(self):
|
||||
"""main calculate energy"""
|
||||
_ = self.bond.Bond_Energy(self.md_info.uint_crd, self.md_info.uint_dr_to_dr_cof)
|
||||
_ = self.angle.Angle_Energy(self.md_info.uint_crd, self.md_info.uint_dr_to_dr_cof)
|
||||
_ = self.dihedral.Dihedral_Engergy(self.md_info.uint_crd, self.md_info.uint_dr_to_dr_cof)
|
||||
_ = self.nb14.Non_Bond_14_LJ_CF_Energy(self.md_info.uint_crd_with_LJ, self.md_info.uint_dr_to_dr_cof,
|
||||
self.LJ_info.LJ_A,
|
||||
self.LJ_info.LJ_B)
|
||||
|
||||
_ = self.LJ_info.LJ_Energy(self.md_info.uint_crd_with_LJ, self.md_info.uint_dr_to_dr_cof,
|
||||
self.nb_info.nl_atom_numbers, self.nb_info.nl_atom_serial,
|
||||
self.nb_info.cutoff_square)
|
||||
_ = self.pme_method.PME_Energy(
|
||||
self.md_info.uint_crd, self.md_info.charge, self.nb_info.nl_atom_numbers, self.nb_info.nl_atom_serial,
|
||||
self.md_info.uint_dr_to_dr_cof, self.nb_info.excluded_list_start, self.nb_info.excluded_list,
|
||||
self.nb_info.excluded_numbers)
|
||||
_ = self.pme_method.Energy_Device_To_Host()
|
||||
|
||||
def Main_After_Calculate_Energy(self):
|
||||
"""main after calculate energy"""
|
||||
md_info = self.md_info
|
||||
LJ_info = self.LJ_info
|
||||
bond = self.bond
|
||||
angle = self.angle
|
||||
dihedral = self.dihedral
|
||||
nb14 = self.nb14
|
||||
pme_method = self.pme_method
|
||||
|
||||
md_info.total_potential_energy = 0
|
||||
md_info.total_potential_energy += bond.sigma_of_bond_ene
|
||||
md_info.total_potential_energy += angle.sigma_of_angle_ene
|
||||
md_info.total_potential_energy += dihedral.sigma_of_dihedral_ene
|
||||
md_info.total_potential_energy += nb14.nb14_lj_energy_sum + nb14.nb14_cf_energy_sum
|
||||
md_info.total_potential_energy += LJ_info.LJ_energy_sum
|
||||
pme_method.Energy_Device_To_Host()
|
||||
md_info.total_potential_energy += pme_method.ee_ene
|
||||
print("md_info.total_potential_energy", md_info.total_potential_energy)
|
||||
|
||||
def Main_Iteration_2(self):
|
||||
"""main iteration2"""
|
||||
md_info = self.md_info
|
||||
control = self.control
|
||||
liujian_info = self.liujian_info
|
||||
|
||||
if md_info.mode > 0 and int(control.Command_Set["thermostat"]) == 1:
|
||||
md_info.vel, md_info.crd, md_info.frc, md_info.acc = liujian_info.MD_Iteration_Leap_Frog(
|
||||
md_info.d_mass_inverse, md_info.vel, md_info.crd, md_info.frc)
|
||||
|
||||
def Main_After_Iteration(self):
|
||||
"""main after iteration"""
|
||||
md_info = self.md_info
|
||||
nb_info = self.nb_info
|
||||
md_info.Centerize()
|
||||
_ = nb_info.NeighborListUpdate(md_info.crd, md_info.crd_old, md_info.uint_crd,
|
||||
md_info.crd_to_uint_crd_cof,
|
||||
md_info.uint_dr_to_dr_cof, self.box_length, not_first_time=1)
|
||||
|
||||
def Main_Print(self):
|
||||
"""compute the temperature"""
|
||||
md_info = self.md_info
|
||||
temperature = md_info.MD_Information_Temperature()
|
||||
md_info.h_temperature = temperature
|
||||
steps = md_info.steps
|
||||
temperature = temperature.asnumpy()
|
||||
total_potential_energy = md_info.total_potential_energy.asnumpy()
|
||||
sigma_of_bond_ene = self.bond.sigma_of_bond_ene.asnumpy()
|
||||
sigma_of_angle_ene = self.angle.sigma_of_angle_ene.asnumpy()
|
||||
sigma_of_dihedral_ene = self.dihedral.sigma_of_dihedral_ene.asnumpy()
|
||||
nb14_lj_energy_sum = self.nb14.nb14_lj_energy_sum.asnumpy()
|
||||
nb14_cf_energy_sum = self.nb14.nb14_cf_energy_sum.asnumpy()
|
||||
LJ_energy_sum = self.LJ_info.LJ_energy_sum.asnumpy()
|
||||
ee_ene = self.pme_method.ee_ene.asnumpy()
|
||||
print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
|
||||
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_")
|
||||
print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)), end=" ")
|
||||
if self.bond.bond_numbers > 0:
|
||||
print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ")
|
||||
if self.angle.angle_numbers > 0:
|
||||
print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ")
|
||||
if self.dihedral.dihedral_numbers > 0:
|
||||
print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ")
|
||||
if self.nb14.nb14_numbers > 0:
|
||||
print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ")
|
||||
|
||||
print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ")
|
||||
print("{:>12.3f}".format(float(ee_ene)))
|
||||
|
||||
if self.file is not None:
|
||||
self.file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}"
|
||||
" {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy),
|
||||
float(sigma_of_bond_ene), float(sigma_of_angle_ene),
|
||||
float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum),
|
||||
float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene)))
|
||||
|
||||
return temperature
|
||||
|
||||
def Main_Initial(self):
|
||||
"""main initial"""
|
||||
if self.control.mdout:
|
||||
self.file = open(self.control.mdout, 'w')
|
||||
self.file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
|
||||
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n")
|
||||
|
||||
def Main_Destroy(self):
|
||||
"""main destroy"""
|
||||
if self.file is not None:
|
||||
self.file.close()
|
||||
print("Save successfully!")
|
Loading…
Reference in New Issue