!11666 Adjust the output of MD simulation
From: @zhangxinfeng3 Reviewed-by: @ckey_dou,@wang_zi_dong Signed-off-by: @wang_zi_dong
This commit is contained in:
commit
3a65a528b9
|
@ -13,7 +13,7 @@
|
|||
|
||||
## Description
|
||||
|
||||
Molecular Dynamics (MD) is playing an increasingly important role in the research of biology, pharmacy, chemistry, and materials science. The architecture is based on DeePMD, which using an NN scheme for MD simulations, which overcomes the limitations associated to auxiliary quantities like the symmetry functions or the Coulomb matrix. Each environment contains a number of atoms, whose local coordinates are arranged in a symmetry preserving way following the prescription of the Deep Potential method. According to the atomic position, atomic types and box tensor to construct energy, force and virial.
|
||||
Molecular Dynamics (MD) is playing an increasingly important role in the research of biology, pharmacy, chemistry, and materials science. The architecture is based on DeePMD, which using an NN scheme for MD simulations, which overcomes the limitations associated to auxiliary quantities like the symmetry functions or the Coulomb matrix. Each environment contains a number of atoms, whose local coordinates are arranged in a symmetry preserving way following the prescription of the Deep Potential method. According to the atomic position, atomic types and box tensor to construct energy.
|
||||
|
||||
Thanks a lot for DeePMD team's help.
|
||||
|
||||
|
@ -75,7 +75,6 @@ In `deepmodeling/deepmd-kit/source`:
|
|||
│ ├── eval.sh # evaluation script
|
||||
├── src
|
||||
│ ├── descriptor.py # descriptor function
|
||||
│ ├── virial.py # calculating virial function
|
||||
│ └── network.py # MD simulation architecture
|
||||
└── eval.py # evaluation interface
|
||||
```
|
||||
|
@ -101,8 +100,6 @@ The infer result:
|
|||
```text
|
||||
energy: -29944.03
|
||||
atom_energy: -94.38766 -94.294426 -94.39194 -94.70758 -94.51311 -94.457954 ...
|
||||
force: 1.649112 -1.0982257 0.46055675 -1.3491507 -0.3382736 -0.97184074 ...
|
||||
virial: -11.736662 -4.2862144 2.8852935 -4.286209 -10.408775 -5.6738224 ...
|
||||
```
|
||||
|
||||
## ModelZoo Homepage
|
||||
|
|
|
@ -49,9 +49,7 @@ if __name__ == '__main__':
|
|||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.to_float(mstype.float32)
|
||||
energy, atom_ener, force, virial = \
|
||||
energy, atom_ener, _ = \
|
||||
net(d_coord_tensor, d_nlist_tensor, frames, avg_tensor, std_tensor, atype_tensor, nlist_tensor)
|
||||
print('energy:', energy)
|
||||
print('atom_energy:', atom_ener)
|
||||
print('force:', force)
|
||||
print('virial:', virial)
|
||||
|
|
|
@ -22,7 +22,6 @@ from mindspore import Tensor
|
|||
from mindspore import nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from .virial import ProdVirialSeA
|
||||
from .descriptor import DescriptorSeA
|
||||
|
||||
natoms = [192, 192, 64, 128]
|
||||
|
@ -230,12 +229,10 @@ class Network(nn.Cell):
|
|||
self.grad = Grad(self.mdnet)
|
||||
self.descrpt_se_a = DescriptorSeA()
|
||||
self.process = Processing()
|
||||
self.prod_virial_se_a = ProdVirialSeA()
|
||||
self.prod_force_se_a = P.ProdForceSeA()
|
||||
|
||||
def construct(self, d_coord, d_nlist, frames, avg, std, atype, nlist):
|
||||
"""construct function."""
|
||||
rij, descrpt, descrpt_deriv = \
|
||||
_, descrpt, _ = \
|
||||
self.descrpt_se_a(d_coord, d_nlist, frames, avg, std, atype)
|
||||
# calculate energy and atom_ener
|
||||
atom_ener = self.mdnet(descrpt)
|
||||
|
@ -244,10 +241,4 @@ class Network(nn.Cell):
|
|||
energy = self.sum(energy_raw, 1)
|
||||
# grad of atom_ener
|
||||
net_deriv = self.grad(descrpt)
|
||||
net_deriv_reshape = self.reshape(net_deriv[0], (-1, natoms[0], ndescrpt))
|
||||
descrpt_deriv_reshape = self.reshape(descrpt_deriv, (-1, natoms[0], ndescrpt, 3))
|
||||
# calculate virial
|
||||
virial = self.prod_virial_se_a(net_deriv_reshape, descrpt_deriv_reshape, rij, nlist)
|
||||
# calculate force
|
||||
force = self.prod_force_se_a(net_deriv_reshape, descrpt_deriv_reshape, nlist)
|
||||
return energy, atom_ener, force, virial
|
||||
return energy, atom_ener, net_deriv
|
||||
|
|
|
@ -1,68 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Calculate virial of atoms."""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class ProdVirialSeA(nn.Cell):
|
||||
"""calculate virial."""
|
||||
def __init__(self):
|
||||
super(ProdVirialSeA, self).__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
self.rsum = P.ReduceSum()
|
||||
self.rksum = P.ReduceSum(keep_dims=True)
|
||||
self.broadcastto1 = P.BroadcastTo((1, 192, 138, 4, 3, 3))
|
||||
self.broadcastto2 = P.BroadcastTo((1, 192, 138, 4, 3))
|
||||
self.broadcastto3 = P.BroadcastTo((1, 192, 138, 3))
|
||||
self.expdims = P.ExpandDims()
|
||||
|
||||
def construct(self, net_deriv_reshape, descrpt_deriv, rij, nlist):
|
||||
"""construct function."""
|
||||
descrpt_deriv = self.cast(descrpt_deriv, mstype.float32)
|
||||
descrpt_deriv = self.reshape(descrpt_deriv, (1, 192, 138, 4, 3))
|
||||
|
||||
net_deriv_reshape = self.cast(net_deriv_reshape, mstype.float32)
|
||||
net_deriv_reshape = self.reshape(net_deriv_reshape, (1, 192, 138, 4))
|
||||
net_deriv_reshape = self.expdims(net_deriv_reshape, 4)
|
||||
net_deriv_reshape = self.broadcastto2(net_deriv_reshape)
|
||||
|
||||
rij = self.cast(rij, mstype.float32)
|
||||
rij = self.reshape(rij, (1, 192, 138, 3))
|
||||
rij = self.expdims(rij, 3)
|
||||
rij = self.expdims(rij, 4)
|
||||
rij = self.broadcastto1(rij)
|
||||
|
||||
nlist = self.cast(nlist, mstype.int32)
|
||||
nlist = self.reshape(nlist, (1, 192, 138))
|
||||
nlist = self.expdims(nlist, 3)
|
||||
nlist = self.broadcastto3(nlist)
|
||||
|
||||
tmp = descrpt_deriv * net_deriv_reshape
|
||||
|
||||
b_blist = self.cast(nlist > -1, mstype.int32)
|
||||
b_blist = self.expdims(b_blist, 3)
|
||||
b_blist = self.broadcastto2(b_blist)
|
||||
|
||||
tmp_1 = tmp * b_blist
|
||||
tmp_1 = self.expdims(tmp_1, 5)
|
||||
tmp_1 = self.broadcastto1(tmp_1)
|
||||
|
||||
out = tmp_1 * rij
|
||||
out = self.rsum(out, (1, 2, 3))
|
||||
return out
|
Loading…
Reference in New Issue