forked from mindspore-Ecosystem/mindspore
update force calculation in MD
This commit is contained in:
parent
9cfaba8983
commit
4d3602514f
|
@ -43,5 +43,6 @@ from .select import _select_akg
|
|||
from .sqrt import _sqrt_akg
|
||||
from .square import _square_akg
|
||||
from .sub import _sub_akg
|
||||
from .prod_force_se_a import _prod_force_se_a_akg
|
||||
|
||||
# Please insert op register in lexicographical order of the filename.
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ProdForceSeA op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
|
||||
|
||||
op_info = AkgAscendRegOp("ProdForceSeA") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.attr("natoms", "required", "int") \
|
||||
.input(0, "net_deriv_tensor") \
|
||||
.input(1, "in_deriv_tensor") \
|
||||
.input(2, "nlist_tensor") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DT.F32_Default, DT.F32_Default, DT.I32_Default, DT.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(op_info)
|
||||
def _prod_force_se_a_akg():
|
||||
"""ProdForceSeA Akg register"""
|
||||
return
|
|
@ -88,7 +88,8 @@ from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingB
|
|||
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
|
||||
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
|
||||
CusMatMulCubeDenseRight,
|
||||
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle)
|
||||
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle,
|
||||
ProdForceSeA)
|
||||
from .sparse_ops import SparseToDense
|
||||
from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter
|
||||
|
||||
|
|
|
@ -709,3 +709,22 @@ class DetTriangle(PrimitiveWithInfer):
|
|||
def infer_dtype(self, x1_dtype):
|
||||
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
|
||||
return x1_dtype
|
||||
|
||||
|
||||
class ProdForceSeA(PrimitiveWithInfer):
|
||||
"""
|
||||
ProdForceSeA.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, natoms=192):
|
||||
self.init_prim_io_names(inputs=['net_deriv_tensor', "in_deriv_tensor", "nlist_tensor"], outputs=['y'])
|
||||
self.natoms = natoms
|
||||
self.add_prim_attr('natoms', self.natoms)
|
||||
|
||||
def infer_shape(self, x1_shape, x2_shape, x3_shape):
|
||||
out_shape = [x3_shape[0], x3_shape[1], 3]
|
||||
return out_shape
|
||||
|
||||
def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype):
|
||||
return x1_dtype
|
||||
|
|
|
@ -31,9 +31,9 @@ The overall network architecture of MD simulation is show below.
|
|||
|
||||
Dataset used: deepmodeling/deepmd-kit/examples/water/data
|
||||
|
||||
The data is generated by Quantum Espresso and the input of Quantum Espresso is setted manually.
|
||||
The data is generated by Quantum Espresso and the input of Quantum Espresso is set manually.
|
||||
|
||||
The directory structure of the data is as follows:
|
||||
The directory structure of the dataset is as follows:
|
||||
|
||||
```text
|
||||
└─data
|
||||
|
@ -50,10 +50,10 @@ The directory structure of the data is as follows:
|
|||
|
||||
In `deepmodeling/deepmd-kit/source`:
|
||||
|
||||
- Use `train/DataSystem.py` to get coord and atype.
|
||||
- Use `train/DataSystem.py` to get d_coord and atype.
|
||||
- Use function compute_input_stats in `train/DataSystem.py` to get avg and std.
|
||||
- Use `op/descrpt_se_a.cc` to get nlist.
|
||||
- Save coord, atype, avg, std and nlist as `Npz` file for infer.
|
||||
- Use `op/descrpt_se_a.cc` to get d_nlist and nlist.
|
||||
- Save d_coord, d_nlist, atype, avg, std and nlist as `Npz` file for inference.
|
||||
|
||||
## Environment Requirements
|
||||
|
||||
|
@ -96,12 +96,13 @@ python eval.py --dataset_path [DATASET_PATH] --checkpoint_path [CHECKPOINT_PATH]
|
|||
|
||||
### Result
|
||||
|
||||
推理的结果如下:
|
||||
The infer result:
|
||||
|
||||
```text
|
||||
atom_ener: -94.38766 -94.294426 -94.39194 -94.70758 -94.51311 -94.457954 ...
|
||||
force: 1.64911175 -1.09822524 0.46055657 -1.34915102 -0.33827361 -0.97184098 ...
|
||||
virial: -11.736662 -4.286214 2.8852937 -4.286209 -10.408775 -5.6738234 ...
|
||||
energy: -29944.03
|
||||
atom_energy: -94.38766 -94.294426 -94.39194 -94.70758 -94.51311 -94.457954 ...
|
||||
force: 1.649112 -1.0982257 0.46055675 -1.3491507 -0.3382736 -0.97184074 ...
|
||||
virial: -11.736662 -4.2862144 2.8852935 -4.286209 -10.408775 -5.6738224 ...
|
||||
```
|
||||
|
||||
## ModelZoo Homepage
|
||||
|
|
|
@ -32,11 +32,12 @@ context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="A
|
|||
if __name__ == '__main__':
|
||||
# get input data
|
||||
r = np.load(args_opt.dataset_path)
|
||||
d_coord, d_nlist, avg, std, atype = r['d_coord'], r['d_nlist'], r['avg'], r['std'], r['atype']
|
||||
d_coord, d_nlist, avg, std, atype, nlist = r['d_coord'], r['d_nlist'], r['avg'], r['std'], r['atype'], r['nlist']
|
||||
batch_size = 1
|
||||
atype_tensor = Tensor(atype)
|
||||
avg_tensor = Tensor(avg)
|
||||
std_tensor = Tensor(std)
|
||||
nlist_tensor = Tensor(nlist)
|
||||
d_coord_tensor = Tensor(np.reshape(d_coord, (1, -1, 3)))
|
||||
d_nlist_tensor = Tensor(d_nlist)
|
||||
frames = []
|
||||
|
@ -48,8 +49,9 @@ if __name__ == '__main__':
|
|||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.to_float(mstype.float32)
|
||||
energy, atom_ener, virial = \
|
||||
net(d_coord_tensor, d_nlist_tensor, frames, avg_tensor, std_tensor, atype_tensor)
|
||||
print(energy)
|
||||
print(atom_ener)
|
||||
print(virial)
|
||||
energy, atom_ener, force, virial = \
|
||||
net(d_coord_tensor, d_nlist_tensor, frames, avg_tensor, std_tensor, atype_tensor, nlist_tensor)
|
||||
print('energy:', energy)
|
||||
print('atom_energy:', atom_ener)
|
||||
print('force:', force)
|
||||
print('virial:', virial)
|
||||
|
|
|
@ -22,8 +22,8 @@ from mindspore import Tensor
|
|||
from mindspore import nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from virial import ProdVirialSeA
|
||||
from descriptor import DescriptorSeA
|
||||
from .virial import ProdVirialSeA
|
||||
from .descriptor import DescriptorSeA
|
||||
|
||||
natoms = [192, 192, 64, 128]
|
||||
rcut_a = -1
|
||||
|
@ -231,11 +231,12 @@ class Network(nn.Cell):
|
|||
self.descrpt_se_a = DescriptorSeA()
|
||||
self.process = Processing()
|
||||
self.prod_virial_se_a = ProdVirialSeA()
|
||||
self.prod_force_se_a = P.ProdForceSeA()
|
||||
|
||||
def construct(self, coord, nlist, frames, avg, std, atype):
|
||||
def construct(self, d_coord, d_nlist, frames, avg, std, atype, nlist):
|
||||
"""construct function."""
|
||||
rij, descrpt, descrpt_deriv = \
|
||||
self.descrpt_se_a(coord, nlist, frames, avg, std, atype)
|
||||
self.descrpt_se_a(d_coord, d_nlist, frames, avg, std, atype)
|
||||
# calculate energy and atom_ener
|
||||
atom_ener = self.mdnet(descrpt)
|
||||
energy_raw = atom_ener
|
||||
|
@ -247,4 +248,6 @@ class Network(nn.Cell):
|
|||
descrpt_deriv_reshape = self.reshape(descrpt_deriv, (-1, natoms[0], ndescrpt, 3))
|
||||
# calculate virial
|
||||
virial = self.prod_virial_se_a(net_deriv_reshape, descrpt_deriv_reshape, rij, nlist)
|
||||
return energy, atom_ener, virial
|
||||
# calculate force
|
||||
force = self.prod_force_se_a(net_deriv_reshape, descrpt_deriv_reshape, nlist)
|
||||
return energy, atom_ener, force, virial
|
||||
|
|
Loading…
Reference in New Issue