forked from mindspore-Ecosystem/mindspore
!9458 Add MD simulation in model zoo
From: @zhangxinfeng3 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
228e889d00
|
@ -0,0 +1,110 @@
|
|||
# Contents
|
||||
|
||||
- [Description](#description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Training Process](#training-process)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Result](#result)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
## Description
|
||||
|
||||
Molecular Dynamics (MD) is playing an increasingly important role in the research of biology, pharmacy, chemistry, and materials science. The architecture is based on DeePMD, which using an NN scheme for MD simulations, which overcomes the limitations associated to auxiliary quantities like the symmetry functions or the Coulomb matrix. Each environment contains a number of atoms, whose local coordinates are arranged in a symmetry preserving way following the prescription of the Deep Potential method. According to the atomic position, atomic types and box tensor to construct energy, force and virial.
|
||||
|
||||
Thanks a lot for DeePMD team's help.
|
||||
|
||||
[1] Paper: L Zhang, J Han, H Wang, R Car, W E. Deep potential molecular dynamics: a scalable model with the accuracy of quantum mechanics. Physical review letters 120 (14), 143001 (2018).
|
||||
|
||||
[2] Paper: H Wang, L Zhang, J Han, W E. DeePMD-kit: A deep learning package for many-body potential energy representation and molecular dynamics. Computer Physics Communications 228, 178-184 (2018).
|
||||
|
||||
## Model Architecture
|
||||
|
||||
The overall network architecture of MD simulation is show below.
|
||||
|
||||
[Link](https://arxiv.org/abs/1707.09571)
|
||||
|
||||
## Dataset
|
||||
|
||||
Dataset used: deepmodeling/deepmd-kit/examples/water/data
|
||||
|
||||
The data is generated by Quantum Espresso and the input of Quantum Espresso is setted manually.
|
||||
|
||||
The directory structure of the data is as follows:
|
||||
|
||||
```text
|
||||
└─data
|
||||
├─type.raw
|
||||
├─set.000
|
||||
│ ├──box.npy
|
||||
│ ├──coord.npy
|
||||
│ ├──energy.npy
|
||||
│ └──force.npy
|
||||
├─set.001
|
||||
├─set.002
|
||||
└─set.003
|
||||
```
|
||||
|
||||
In `deepmodeling/deepmd-kit/source`:
|
||||
|
||||
- Use `train/DataSystem.py` to get coord and atype.
|
||||
- Use function compute_input_stats in `train/DataSystem.py` to get avg and std.
|
||||
- Use `op/descrpt_se_a.cc` to get nlist.
|
||||
- Save coord, atype, avg, std and nlist as `Npz` file for infer.
|
||||
|
||||
## Environment Requirements
|
||||
|
||||
- Hardware (Ascend)
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
## Script Description
|
||||
|
||||
### Script and Sample Code
|
||||
|
||||
```shell
|
||||
├── md
|
||||
├── README.md # descriptions about MD
|
||||
├── script
|
||||
│ ├── eval.sh # evaluation script
|
||||
├── src
|
||||
│ ├── descriptor.py # descriptor function
|
||||
│ ├── virial.py # calculating virial function
|
||||
│ └── network.py # MD simulation architecture
|
||||
└── eval.py # evaluation interface
|
||||
```
|
||||
|
||||
### Training Process
|
||||
|
||||
To Be Done
|
||||
|
||||
### Evaluation Process
|
||||
|
||||
After installing MindSpore via the official website, you can start evaluation as follows:
|
||||
|
||||
```shell
|
||||
python eval.py --dataset_path [DATASET_PATH] --checkpoint_path [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
> checkpoint can be trained by using DeePMD-kit, and convert into the ckpt of MindSpore.
|
||||
|
||||
### Result
|
||||
|
||||
推理的结果如下:
|
||||
|
||||
```text
|
||||
atom_ener: -94.38766 -94.294426 -94.39194 -94.70758 -94.51311 -94.457954 ...
|
||||
force: 1.64911175 -1.09822524 0.46055657 -1.34915102 -0.33827361 -0.97184098 ...
|
||||
virial: -11.736662 -4.286214 2.8852937 -4.286209 -10.408775 -5.6738234 ...
|
||||
```
|
||||
|
||||
## ModelZoo Homepage
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval."""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.network import Network
|
||||
|
||||
parser = argparse.ArgumentParser(description='MD Simulation')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# get input data
|
||||
r = np.load(args_opt.dataset_path)
|
||||
d_coord, d_nlist, avg, std, atype = r['d_coord'], r['d_nlist'], r['avg'], r['std'], r['atype']
|
||||
batch_size = 1
|
||||
atype_tensor = Tensor(atype)
|
||||
avg_tensor = Tensor(avg)
|
||||
std_tensor = Tensor(std)
|
||||
d_coord_tensor = Tensor(np.reshape(d_coord, (1, -1, 3)))
|
||||
d_nlist_tensor = Tensor(d_nlist)
|
||||
frames = []
|
||||
for i in range(batch_size):
|
||||
frames.append(i * 1536)
|
||||
frames = Tensor(frames)
|
||||
# evaluation
|
||||
net = Network()
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.to_float(mstype.float32)
|
||||
energy, atom_ener, virial = \
|
||||
net(d_coord_tensor, d_nlist_tensor, frames, avg_tensor, std_tensor, atype_tensor)
|
||||
print(energy)
|
||||
print(atom_ener)
|
||||
print(virial)
|
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# eval script
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
DATA_PATH=$1
|
||||
CKPT_PATH=$2
|
||||
python -s ${self_path}/../eval.py --dataset_path=./$DATA_PATH --checkpoint_path=./$CKPT_PATH > log.txt 2>&1 &
|
|
@ -0,0 +1,207 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""The construction of the descriptor."""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class ComputeRij(nn.Cell):
|
||||
"""compute rij."""
|
||||
def __init__(self):
|
||||
super(ComputeRij, self).__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
self.rsum = P.ReduceSum()
|
||||
self.broadcastto = P.BroadcastTo((1, 192 * 138))
|
||||
self.broadcastto1 = P.BroadcastTo((1, 192, 138, 3))
|
||||
self.expdims = P.ExpandDims()
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.gather = P.GatherV2()
|
||||
self.mul = P.Mul()
|
||||
self.slice = P.Slice()
|
||||
|
||||
def construct(self, d_coord_tensor, nlist_tensor, frames):
|
||||
"""construct function."""
|
||||
d_coord_tensor = self.cast(d_coord_tensor, mstype.float32)
|
||||
d_coord_tensor = self.reshape(d_coord_tensor, (1, -1, 3))
|
||||
coord_tensor = self.slice(d_coord_tensor, (0, 0, 0), (1, 192, 3))
|
||||
|
||||
nlist_tensor = self.cast(nlist_tensor, mstype.int32)
|
||||
nlist_tensor = self.reshape(nlist_tensor, (1, 192, 138))
|
||||
|
||||
b_nlist = nlist_tensor > -1
|
||||
b_nlist = self.cast(b_nlist, mstype.int32)
|
||||
nlist_tensor_r = b_nlist * nlist_tensor
|
||||
nlist_tensor_r = self.reshape(nlist_tensor_r, (-1,))
|
||||
|
||||
frames = self.cast(frames, mstype.int32)
|
||||
frames = self.expdims(frames, 1)
|
||||
frames = self.broadcastto(frames)
|
||||
frames = self.reshape(frames, (-1,))
|
||||
|
||||
nlist_tensor_r = nlist_tensor_r + frames
|
||||
nlist_tensor_r = self.reshape(nlist_tensor_r, (-1,))
|
||||
|
||||
d_coord_tensor = self.reshape(d_coord_tensor, (-1, 3))
|
||||
selected_coord = self.gather(d_coord_tensor, nlist_tensor_r, 0)
|
||||
selected_coord = self.reshape(selected_coord, (1, 192, 138, 3))
|
||||
|
||||
coord_tensor_expanded = self.expdims(coord_tensor, 2)
|
||||
coord_tensor_expanded = self.broadcastto1(coord_tensor_expanded)
|
||||
|
||||
result_rij_m = selected_coord - coord_tensor_expanded
|
||||
|
||||
b_nlist_expanded = self.expdims(b_nlist, 3)
|
||||
b_nlist_expanded = self.broadcastto1(b_nlist_expanded)
|
||||
|
||||
result_rij = result_rij_m * b_nlist_expanded
|
||||
|
||||
return result_rij
|
||||
|
||||
|
||||
class ComputeDescriptor(nn.Cell):
|
||||
"""compute descriptor."""
|
||||
def __init__(self):
|
||||
super(ComputeDescriptor, self).__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
self.rsum = P.ReduceSum()
|
||||
self.broadcastto = P.BroadcastTo((1, 192 * 138))
|
||||
self.broadcastto1 = P.BroadcastTo((1, 192, 138, 3))
|
||||
self.broadcastto2 = P.BroadcastTo((1, 192, 138, 3, 3))
|
||||
self.broadcastto3 = P.BroadcastTo((1, 192, 138, 4))
|
||||
self.broadcastto4 = P.BroadcastTo((1, 192, 138, 4, 3))
|
||||
|
||||
self.expdims = P.ExpandDims()
|
||||
self.concat = P.Concat(axis=3)
|
||||
self.gather = P.GatherV2()
|
||||
self.mul = P.Mul()
|
||||
self.slice = P.Slice()
|
||||
self.square = P.Square()
|
||||
self.inv = P.Inv()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.ones = P.OnesLike()
|
||||
self.eye = P.Eye()
|
||||
|
||||
def construct(self, rij_tensor, avg_tensor, std_tensor, nlist_tensor, atype_tensor, r_min=5.8, r_max=6.0):
|
||||
"""construct function."""
|
||||
nlist_tensor = self.reshape(nlist_tensor, (1, 192, 138))
|
||||
b_nlist = nlist_tensor > -1
|
||||
b_nlist = self.cast(b_nlist, mstype.int32)
|
||||
b_nlist_expanded = self.expdims(b_nlist, 3)
|
||||
b_nlist_4 = self.broadcastto3(b_nlist_expanded)
|
||||
b_nlist_3 = self.broadcastto1(b_nlist_expanded)
|
||||
b_nlist_expanded = self.expdims(b_nlist_expanded, 4)
|
||||
b_nlist_33 = self.broadcastto2(b_nlist_expanded)
|
||||
|
||||
rij_tensor = rij_tensor + self.cast(1 - b_nlist_3, mstype.float32)
|
||||
|
||||
r_2 = self.square(rij_tensor)
|
||||
d_2 = self.rsum(r_2, 3)
|
||||
invd_2 = self.inv(d_2)
|
||||
invd = self.sqrt(invd_2)
|
||||
invd_4 = self.square(invd_2)
|
||||
d = invd * d_2
|
||||
invd_3 = invd_4 * d
|
||||
|
||||
b_d_1 = self.cast(d < r_max, mstype.int32)
|
||||
b_d_2 = self.cast(d < r_min, mstype.int32)
|
||||
b_d_3 = self.cast(d >= r_min, mstype.int32)
|
||||
|
||||
du = 1.0 / (r_max - r_min)
|
||||
uu = (d - r_min) * du
|
||||
vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1
|
||||
dd = (3 * uu * uu * (-6 * uu * uu + 15 * uu - 10) + uu * uu * uu * (-12 * uu + 15)) * du
|
||||
|
||||
sw = vv * b_d_3 * b_d_1 + b_d_2
|
||||
dsw = dd * b_d_3 * b_d_1
|
||||
|
||||
invd_2_e = self.expdims(invd_2, 3)
|
||||
invd_2_e = self.broadcastto1(invd_2_e)
|
||||
descrpt_1 = rij_tensor * invd_2_e
|
||||
|
||||
factor0 = invd_3 * sw - invd_2 * dsw
|
||||
factor0 = self.expdims(factor0, 3)
|
||||
factor0 = self.broadcastto1(factor0)
|
||||
descrpt_deriv_0 = rij_tensor * factor0
|
||||
descrpt_deriv_0 = descrpt_deriv_0 * b_nlist_3
|
||||
descrpt_deriv_0 = self.expdims(descrpt_deriv_0, 3)
|
||||
|
||||
factor1_0 = self.eye(3, 3, mstype.float32)
|
||||
factor1_0 = self.expdims(factor1_0, 0)
|
||||
factor1_0 = self.expdims(factor1_0, 0)
|
||||
factor1_0 = self.expdims(factor1_0, 0)
|
||||
factor1_1 = self.expdims(invd_2 * sw, 3)
|
||||
factor1_1 = self.expdims(factor1_1, 4)
|
||||
descrpt_deriv_1_0 = factor1_0 * factor1_1
|
||||
|
||||
rij_tensor_e1 = self.expdims(rij_tensor, 4)
|
||||
rij_tensor_e2 = self.expdims(rij_tensor, 3)
|
||||
rij_tensor_e1 = self.broadcastto2(rij_tensor_e1)
|
||||
rij_tensor_e2 = self.broadcastto2(rij_tensor_e2)
|
||||
|
||||
factor1_3 = self.expdims(2.0 * invd_4 * sw, 3)
|
||||
factor1_3 = self.expdims(factor1_3, 4)
|
||||
factor1_3 = self.broadcastto2(factor1_3)
|
||||
descrpt_deriv_1_1 = factor1_3 * rij_tensor_e1 * rij_tensor_e2
|
||||
|
||||
factor1_4 = self.expdims(invd * dsw, 3)
|
||||
factor1_4 = self.expdims(factor1_4, 3)
|
||||
factor1_4 = self.broadcastto2(factor1_4)
|
||||
descrpt_1_e = self.expdims(descrpt_1, 4)
|
||||
descrpt_1_e = self.broadcastto2(descrpt_1_e)
|
||||
descrpt_deriv_1_2 = descrpt_1_e * rij_tensor_e2 * factor1_4
|
||||
|
||||
descrpt_deriv_1 = (descrpt_deriv_1_1 - descrpt_deriv_1_0 - descrpt_deriv_1_2) * b_nlist_33
|
||||
|
||||
descrpt_deriv = self.concat((descrpt_deriv_0, descrpt_deriv_1))
|
||||
|
||||
invd_e = self.expdims(invd, 3)
|
||||
descrpt = self.concat((invd_e, descrpt_1))
|
||||
sw = self.broadcastto3(self.expdims(sw, 3))
|
||||
descrpt = descrpt * sw * b_nlist_4
|
||||
|
||||
avg_tensor = self.cast(avg_tensor, mstype.float32)
|
||||
std_tensor = self.cast(std_tensor, mstype.float32)
|
||||
|
||||
atype_tensor = self.reshape(atype_tensor, (-1,))
|
||||
atype_tensor = self.cast(atype_tensor, mstype.int32)
|
||||
avg_tensor = self.gather(avg_tensor, atype_tensor, 0)
|
||||
std_tensor = self.gather(std_tensor, atype_tensor, 0)
|
||||
avg_tensor = self.reshape(avg_tensor, (1, 192, 138, 4))
|
||||
std_tensor = self.reshape(std_tensor, (1, 192, 138, 4))
|
||||
|
||||
std_tensor_2 = self.expdims(std_tensor, 4)
|
||||
std_tensor_2 = self.broadcastto4(std_tensor_2)
|
||||
|
||||
descrpt = (descrpt - avg_tensor) / std_tensor
|
||||
descrpt_deriv = descrpt_deriv / std_tensor_2
|
||||
|
||||
return descrpt, descrpt_deriv
|
||||
|
||||
|
||||
class DescriptorSeA(nn.Cell):
|
||||
def __init__(self):
|
||||
super(DescriptorSeA, self).__init__()
|
||||
self.compute_rij = ComputeRij()
|
||||
self.compute_descriptor = ComputeDescriptor()
|
||||
|
||||
def construct(self, coord, nlist, frames, avg, std, atype):
|
||||
rij = self.compute_rij(coord, nlist, frames)
|
||||
descrpt, descrpt_deriv = self.compute_descriptor(rij, avg, std, nlist, atype)
|
||||
return rij, descrpt, descrpt_deriv
|
|
@ -0,0 +1,250 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""The construction of network for molecular dynamics."""
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Parameter
|
||||
from mindspore import Tensor
|
||||
from mindspore import nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from virial import ProdVirialSeA
|
||||
from descriptor import DescriptorSeA
|
||||
|
||||
natoms = [192, 192, 64, 128]
|
||||
rcut_a = -1
|
||||
rcut_r = 6.0
|
||||
rcut_r_smth = 5.8
|
||||
sel_a = [46, 92]
|
||||
sel_r = [0, 0]
|
||||
ntypes = len(sel_a)
|
||||
nnei_a = 138
|
||||
nnei_r = 0
|
||||
nnei = nnei_a + nnei_r
|
||||
ndescrpt_a = nnei_a * 4
|
||||
ndescrpt_r = nnei_r * 1
|
||||
ndescrpt = ndescrpt_a + ndescrpt_r
|
||||
filter_neuron = [25, 50, 100]
|
||||
n_axis_neuron = 16
|
||||
dim_descrpt = filter_neuron[-1] * 16
|
||||
n_neuron = [240, 240, 240]
|
||||
type_bias_ae = [-93.57, -187.15]
|
||||
|
||||
|
||||
class MDNet(nn.Cell):
|
||||
"""MD simulation network."""
|
||||
def __init__(self):
|
||||
super(MDNet, self).__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
self.concat0 = P.Concat(axis=0)
|
||||
self.tanh = nn.Tanh()
|
||||
self.mat = P.MatMul()
|
||||
self.batchmat = nn.MatMul()
|
||||
self.batchmat_tran = nn.MatMul(transpose_x1=True)
|
||||
self.idt1 = Parameter(Tensor(np.random.normal(0.1, 0.001, (240,)), dtype=mstype.float32), name="type0_idt1")
|
||||
self.idt2 = Parameter(Tensor(np.random.normal(0.1, 0.001, (240,)), dtype=mstype.float32), name="type0_idt2")
|
||||
self.idt3 = Parameter(Tensor(np.random.normal(0.1, 0.001, (240,)), dtype=mstype.float32), name="type1_idt1")
|
||||
self.idt4 = Parameter(Tensor(np.random.normal(0.1, 0.001, (240,)), dtype=mstype.float32), name="type1_idt2")
|
||||
self.idt = [self.idt1, self.idt2, self.idt3, self.idt4]
|
||||
self.neuron = [dim_descrpt] + n_neuron
|
||||
self.par = [1] + filter_neuron
|
||||
self.process = Processing()
|
||||
fc = []
|
||||
for i in range(3):
|
||||
fc.append(nn.Dense(self.par[i], self.par[i + 1],
|
||||
weight_init=Tensor(np.random.normal(0.0, 1.0 / np.sqrt(self.par[i] + self.par[i + 1]),
|
||||
(self.par[i + 1], self.par[i])),
|
||||
dtype=mstype.float32),
|
||||
bias_init=Tensor(np.random.normal(0.0, 1.0, (self.par[i + 1],)), dtype=mstype.float32)))
|
||||
for i in range(1, 3):
|
||||
fc.append(nn.Dense(self.par[i], self.par[i + 1],
|
||||
weight_init=Tensor(np.random.normal(0.0, 1.0 / np.sqrt(self.par[i] + self.par[i + 1]),
|
||||
(self.par[i + 1], self.par[i])),
|
||||
dtype=mstype.float32),
|
||||
bias_init=Tensor(np.random.normal(0.0, 1.0, (self.par[i + 1],)), dtype=mstype.float32)))
|
||||
for i in range(3):
|
||||
fc.append(nn.Dense(self.par[i], self.par[i + 1],
|
||||
weight_init=Tensor(np.random.normal(0.0, 1.0 / np.sqrt(self.par[i] + self.par[i + 1]),
|
||||
(self.par[i + 1], self.par[i])),
|
||||
dtype=mstype.float32),
|
||||
bias_init=Tensor(np.random.normal(0.0, 1.0, (self.par[i + 1],)), dtype=mstype.float32)))
|
||||
for i in range(1, 3):
|
||||
fc.append(nn.Dense(self.par[i], self.par[i + 1],
|
||||
weight_init=Tensor(np.random.normal(0.0, 1.0 / np.sqrt(self.par[i] + self.par[i + 1]),
|
||||
(self.par[i + 1], self.par[i])),
|
||||
dtype=mstype.float32),
|
||||
bias_init=Tensor(np.random.normal(0.0, 1.0, (self.par[i + 1],)), dtype=mstype.float32)))
|
||||
self.fc = nn.CellList(fc)
|
||||
self.fc0 = deepcopy(self.fc)
|
||||
self.fc2 = [self.fc, self.fc0]
|
||||
|
||||
fc = []
|
||||
for i in range(3):
|
||||
fc.append(nn.Dense(self.neuron[i], self.neuron[i + 1],
|
||||
weight_init=Tensor(
|
||||
np.random.normal(0.0, 1.0 / np.sqrt(self.neuron[i] + self.neuron[i + 1]),
|
||||
(self.neuron[i + 1], self.neuron[i])), dtype=mstype.float32),
|
||||
bias_init=Tensor(np.random.normal(0.0, 1.0, (self.neuron[i + 1],)),
|
||||
dtype=mstype.float32)))
|
||||
fc.append(nn.Dense(240, 1,
|
||||
weight_init=Tensor(
|
||||
np.random.normal(0.0, 1.0 / np.sqrt(240 + 1), (1, 240)), dtype=mstype.float32),
|
||||
bias_init=Tensor(np.random.normal(type_bias_ae[0], 1.0, (1,)), dtype=mstype.float32)))
|
||||
for i in range(3):
|
||||
fc.append(nn.Dense(self.neuron[i], self.neuron[i + 1],
|
||||
weight_init=Tensor(
|
||||
np.random.normal(0.0, 1.0 / np.sqrt(self.neuron[i] + self.neuron[i + 1]),
|
||||
(self.neuron[i + 1], self.neuron[i])), dtype=mstype.float32),
|
||||
bias_init=Tensor(np.random.normal(0.0, 1.0, (self.neuron[i + 1],)),
|
||||
dtype=mstype.float32)))
|
||||
fc.append(nn.Dense(240, 1,
|
||||
weight_init=Tensor(
|
||||
np.random.normal(0.0, 1.0 / np.sqrt(240 + 1), (1, 240)), dtype=mstype.float32),
|
||||
bias_init=Tensor(np.random.normal(type_bias_ae[1], 1.0, (1,)), dtype=mstype.float32)))
|
||||
self.fc1 = nn.CellList(fc)
|
||||
|
||||
xyz_A = np.vstack((np.identity(46), np.zeros([92, 46])))
|
||||
self.xyz_A = Tensor(np.reshape(xyz_A, (1, 138, 46)))
|
||||
xyz_B = np.vstack((np.zeros([46, 92]), np.identity(92)))
|
||||
self.xyz_B = Tensor(np.reshape(xyz_B, (1, 138, 92)))
|
||||
|
||||
xyz_2 = np.vstack((np.identity(n_axis_neuron), np.zeros([self.par[-1] - n_axis_neuron, n_axis_neuron])))
|
||||
self.xyz_2 = Tensor(xyz_2)
|
||||
|
||||
def _filter(self, slice_0, slice_1, inputs, fc):
|
||||
"""filter method."""
|
||||
shape = self.shape(inputs)
|
||||
slice_inputs = (slice_0, slice_1)
|
||||
xyz_scatter_total = []
|
||||
for type_i in range(2):
|
||||
xyz_scatter = slice_inputs[type_i]
|
||||
shape_i = self.shape(xyz_scatter)
|
||||
xyz_scatter = self.reshape(xyz_scatter, (-1, 1))
|
||||
xyz_scatter = self.tanh(fc[type_i * 5 + 0](xyz_scatter))
|
||||
hidden = self.tanh(fc[type_i * 5 + 1](xyz_scatter))
|
||||
xyz_scatter = fc[type_i * 5 + 3](xyz_scatter) + hidden
|
||||
hidden = self.tanh(fc[type_i * 5 + 2](xyz_scatter))
|
||||
xyz_scatter = fc[type_i * 5 + 4](xyz_scatter) + hidden
|
||||
xyz_scatter = self.reshape(xyz_scatter, (-1, shape_i[1], 100))
|
||||
xyz_scatter_total.append(xyz_scatter)
|
||||
xyz_scatter = self.batchmat(self.xyz_A, xyz_scatter_total[0]) + self.batchmat(self.xyz_B, xyz_scatter_total[1])
|
||||
xyz_scatter_1 = self.batchmat_tran(inputs, xyz_scatter)
|
||||
xyz_scatter_1 = xyz_scatter_1 * (4.0 / (shape[1] * shape[2]))
|
||||
xyz_scatter_2 = self.batchmat(xyz_scatter_1, self.xyz_2)
|
||||
result = self.batchmat_tran(xyz_scatter_1, xyz_scatter_2)
|
||||
return result
|
||||
|
||||
def _fitting(self, slice0, slice1, h, slice2, slice3, o):
|
||||
"""fitting method."""
|
||||
l_layer = []
|
||||
slice_data = (slice0, slice1, h, slice2, slice3, o)
|
||||
for type_i in range(2):
|
||||
layer = self._filter(slice_data[3 * type_i], slice_data[3 * type_i + 1], slice_data[3 * type_i + 2],
|
||||
self.fc2[type_i])
|
||||
layer = self.reshape(layer, (-1, dim_descrpt))
|
||||
layer = self.tanh(self.fc1[type_i * 4 + 0](layer))
|
||||
layer = layer + self.tanh(self.fc1[type_i * 4 + 1](layer)) * self.idt[2 * type_i + 0]
|
||||
layer = layer + self.tanh(self.fc1[type_i * 4 + 2](layer)) * self.idt[2 * type_i + 1]
|
||||
final_layer = self.fc1[type_i * 4 + 3](layer)
|
||||
l_layer.append(final_layer)
|
||||
outs = self.concat0((l_layer[0], l_layer[1]))
|
||||
return self.reshape(outs, (-1,))
|
||||
|
||||
def construct(self, inputs):
|
||||
"""construct function."""
|
||||
slice0, slice1, h, slice2, slice3, o = self.process(inputs)
|
||||
dout = self._fitting(slice0, slice1, h, slice2, slice3, o)
|
||||
return dout
|
||||
|
||||
|
||||
class Processing(nn.Cell):
|
||||
"""data process."""
|
||||
def __init__(self):
|
||||
super(Processing, self).__init__()
|
||||
self.slice = P.Slice()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
self.batchmat = nn.MatMul()
|
||||
self.split = P.Split(1, 3)
|
||||
self.concat = P.Concat(axis=1)
|
||||
slice_64 = Tensor(np.hstack((np.identity(64), np.zeros([64, 128]))))
|
||||
slice_128 = Tensor(np.hstack((np.zeros([128, 64]), np.identity(128))))
|
||||
self.slice_0 = [slice_64, slice_128]
|
||||
slice_46 = Tensor(np.hstack((np.identity(46), np.zeros([46, 92]))))
|
||||
slice_92 = Tensor(np.hstack((np.zeros([92, 46]), np.identity(92))))
|
||||
self.slice_1 = [slice_46, slice_92]
|
||||
slice_2 = np.vstack((np.identity(1), np.zeros([3, 1])))
|
||||
self.slice_2 = Tensor(slice_2)
|
||||
|
||||
def construct(self, inputs):
|
||||
"""construct function."""
|
||||
slice_data = []
|
||||
split_tensor = self.split(inputs)
|
||||
split_64 = self.reshape(split_tensor[0], (-1, 138, 4))
|
||||
split_128 = self.reshape(self.concat((split_tensor[1], split_tensor[2])), (-1, 138, 4))
|
||||
split_t = (split_64, split_128)
|
||||
for type_i in range(2):
|
||||
for type_j in range(2):
|
||||
inputs_reshape = self.batchmat(self.slice_1[type_j], split_t[type_i])
|
||||
xyz_scatter = self.batchmat(inputs_reshape, self.slice_2)
|
||||
slice_data.append(xyz_scatter)
|
||||
slice_data.append(split_t[type_i])
|
||||
slice0, slice1, h, slice2, slice3, o = slice_data[0], slice_data[1], slice_data[2], \
|
||||
slice_data[3], slice_data[4], slice_data[5]
|
||||
return slice0, slice1, h, slice2, slice3, o
|
||||
|
||||
|
||||
class Grad(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Grad, self).__init__()
|
||||
self.grad = C.GradOperation(get_all=True)
|
||||
self.network = network
|
||||
|
||||
def construct(self, x):
|
||||
gout = self.grad(self.network)(x)
|
||||
return gout
|
||||
|
||||
|
||||
class Network(nn.Cell):
|
||||
"""The network to calculate energy, force and virial."""
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.sum = P.ReduceSum()
|
||||
self.mdnet = MDNet()
|
||||
self.grad = Grad(self.mdnet)
|
||||
self.descrpt_se_a = DescriptorSeA()
|
||||
self.process = Processing()
|
||||
self.prod_virial_se_a = ProdVirialSeA()
|
||||
|
||||
def construct(self, coord, nlist, frames, avg, std, atype):
|
||||
"""construct function."""
|
||||
rij, descrpt, descrpt_deriv = \
|
||||
self.descrpt_se_a(coord, nlist, frames, avg, std, atype)
|
||||
# calculate energy and atom_ener
|
||||
atom_ener = self.mdnet(descrpt)
|
||||
energy_raw = atom_ener
|
||||
energy_raw = self.reshape(energy_raw, (-1, natoms[0]))
|
||||
energy = self.sum(energy_raw, 1)
|
||||
# grad of atom_ener
|
||||
net_deriv = self.grad(descrpt)
|
||||
net_deriv_reshape = self.reshape(net_deriv[0], (-1, natoms[0], ndescrpt))
|
||||
descrpt_deriv_reshape = self.reshape(descrpt_deriv, (-1, natoms[0], ndescrpt, 3))
|
||||
# calculate virial
|
||||
virial = self.prod_virial_se_a(net_deriv_reshape, descrpt_deriv_reshape, rij, nlist)
|
||||
return energy, atom_ener, virial
|
|
@ -0,0 +1,68 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Calculate virial of atoms."""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class ProdVirialSeA(nn.Cell):
|
||||
"""calculate virial."""
|
||||
def __init__(self):
|
||||
super(ProdVirialSeA, self).__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
self.rsum = P.ReduceSum()
|
||||
self.rksum = P.ReduceSum(keep_dims=True)
|
||||
self.broadcastto1 = P.BroadcastTo((1, 192, 138, 4, 3, 3))
|
||||
self.broadcastto2 = P.BroadcastTo((1, 192, 138, 4, 3))
|
||||
self.broadcastto3 = P.BroadcastTo((1, 192, 138, 3))
|
||||
self.expdims = P.ExpandDims()
|
||||
|
||||
def construct(self, net_deriv_reshape, descrpt_deriv, rij, nlist):
|
||||
"""construct function."""
|
||||
descrpt_deriv = self.cast(descrpt_deriv, mstype.float32)
|
||||
descrpt_deriv = self.reshape(descrpt_deriv, (1, 192, 138, 4, 3))
|
||||
|
||||
net_deriv_reshape = self.cast(net_deriv_reshape, mstype.float32)
|
||||
net_deriv_reshape = self.reshape(net_deriv_reshape, (1, 192, 138, 4))
|
||||
net_deriv_reshape = self.expdims(net_deriv_reshape, 4)
|
||||
net_deriv_reshape = self.broadcastto2(net_deriv_reshape)
|
||||
|
||||
rij = self.cast(rij, mstype.float32)
|
||||
rij = self.reshape(rij, (1, 192, 138, 3))
|
||||
rij = self.expdims(rij, 3)
|
||||
rij = self.expdims(rij, 4)
|
||||
rij = self.broadcastto1(rij)
|
||||
|
||||
nlist = self.cast(nlist, mstype.int32)
|
||||
nlist = self.reshape(nlist, (1, 192, 138))
|
||||
nlist = self.expdims(nlist, 3)
|
||||
nlist = self.broadcastto3(nlist)
|
||||
|
||||
tmp = descrpt_deriv * net_deriv_reshape
|
||||
|
||||
b_blist = self.cast(nlist > -1, mstype.int32)
|
||||
b_blist = self.expdims(b_blist, 3)
|
||||
b_blist = self.broadcastto2(b_blist)
|
||||
|
||||
tmp_1 = tmp * b_blist
|
||||
tmp_1 = self.expdims(tmp_1, 5)
|
||||
tmp_1 = self.broadcastto1(tmp_1)
|
||||
|
||||
out = tmp_1 * rij
|
||||
out = self.rsum(out, (1, 2, 3))
|
||||
return out
|
Loading…
Reference in New Issue