!15331 fix sponge case master 0417

From: @jiahongqian
Reviewed-by: @wang_zi_dong,@ljl0711
Signed-off-by: @ljl0711
This commit is contained in:
mindspore-ci-bot 2021-04-20 22:03:02 +08:00 committed by Gitee
commit d35939603a
10 changed files with 468 additions and 154 deletions

View File

@ -0,0 +1,123 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
__device__ __host__ float fc(float Rij) {
const float PI = 3.141592654;
const float Rc = 1000.0;
return 0.5 * cosf(PI / Rc * Rij) + 0.5;
}
__global__ void Record_Box_Map_Times(int atom_numbers, const float *crd, const float *old_crd, float *box,
int *box_map_times) {
float half_box[3] = {0.5 * box[0], 0.5 * box[1], 0.5 * box[2]};
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < atom_numbers) {
if (crd[3 * i + 0] - old_crd[3 * i + 0] > half_box[0]) {
box_map_times[3 * i + 0] = box_map_times[3 * i + 0] - 1;
} else if (crd[3 * i + 0] - old_crd[3 * i + 0] < -half_box[0]) {
box_map_times[3 * i + 0] = box_map_times[3 * i + 0] + 1;
}
if (crd[3 * i + 1] - old_crd[3 * i + 1] > half_box[1]) {
box_map_times[3 * i + 1] = box_map_times[3 * i + 1] - 1;
} else if (crd[3 * i + 1] - old_crd[3 * i + 1] < -half_box[1]) {
box_map_times[3 * i + 1] = box_map_times[3 * i + 1] + 1;
}
if (crd[3 * i + 2] - old_crd[3 * i + 2] > half_box[2]) {
box_map_times[3 * i + 2] = box_map_times[3 * i + 2] - 1;
} else if (crd[3 * i + 2] - old_crd[3 * i + 2] < -half_box[2]) {
box_map_times[3 * i + 2] = box_map_times[3 * i + 2] + 1;
}
}
}
__global__ void gen_nowarp_crd(int atom_numbers, const float *crd, float *box, int *box_map_times, float *nowarp_crd) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < atom_numbers) {
nowarp_crd[3 * i + 0] = static_cast<float>(box_map_times[3 * i + 0]) * box[0] + crd[3 * i + 0];
nowarp_crd[3 * i + 1] = static_cast<float>(box_map_times[3 * i + 1]) * box[1] + crd[3 * i + 1];
nowarp_crd[3 * i + 2] = static_cast<float>(box_map_times[3 * i + 2]) * box[2] + crd[3 * i + 2];
}
}
__global__ void G_Radial(const int start_serial, const int end_serial, const float *crd, float *g_radial) {
const float Rs = 0.5, Eta = 0.5;
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= start_serial && i < end_serial) {
float rij;
float g_radial_lin = 0.;
for (int j = start_serial; j < end_serial; j = j + 1) {
if (j != i) {
// rij = sqrtf((crd[3*i+0] - crd[j]) * (crd[i] - crd[j]));
rij = sqrtf(normfloat(crd, crd, i, j));
g_radial_lin = g_radial_lin + expf(-Eta * (rij - Rs) * (rij - Rs)) * fc(rij);
} else {
continue;
}
}
g_radial[i] = g_radial_lin;
}
}
__global__ void G_Angular(const int start_serial, const int end_serial, const float *crd, float *g_angular) {
const float Rs = 0.5, Thetas = 3.14, Eta = 0.5, Zeta = 2.0;
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= start_serial && i < end_serial) {
float rij, rik, rjk, theta_jik;
float g_angular_lin = 0.;
for (int j = start_serial; j < end_serial; j = j + 1) {
if (j != i) {
rij = sqrtf(normfloat(crd, crd, i, j));
for (int k = j + 1; k < end_serial; k = k + 1) {
if (k != i) {
rik = sqrtf(normfloat(crd, crd, i, k));
rjk = sqrtf(normfloat(crd, crd, j, k));
theta_jik =
acosf(fmaxf(fminf((rij * rij + rik * rik - rjk * rjk) / (2. * rij * rik), 0.999999), -0.999999));
g_angular_lin = g_angular_lin + powf(1. + cosf(theta_jik - Thetas), Zeta) *
expf(-Eta * powf(0.5 * (rij + rik) - Rs, 2.)) * fc(rij) * fc(rik);
} else {
continue;
}
}
} else {
continue;
}
}
g_angular[i] = powf(2., 1. - Zeta) * g_angular_lin;
}
}
void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f,
const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial,
float *g_angular, cudaStream_t stream) {
Reset_List<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(3 * atom_numbers, box_map_times,
0);
Record_Box_Map_Times<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(
atom_numbers, crd_f, old_crd, box, box_map_times);
gen_nowarp_crd<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(atom_numbers, crd_f, box,
box_map_times, nowarp_crd);
G_Radial<<<1, number, 0, stream>>>(start_serial, end_serial, nowarp_crd, g_radial);
G_Angular<<<1, number, 0, stream>>>(start_serial, end_serial, nowarp_crd, g_angular);
return;
}
void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f,
const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial,
float *g_angular, cudaStream_t stream);

View File

@ -14,12 +14,13 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_
#include "runtime/device/gpu/cuda_common.h"
void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular,
cudaStream_t stream);
void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f,
const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial,
float *g_angular, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_

View File

@ -1,83 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
__device__ __host__ float fc(float Rij) {
const float PI = 3.141592654;
const float Rc = 1000.0;
return 0.5 * cosf(PI / Rc * Rij) + 0.5;
}
__global__ void G_Radial(const int start_serial, const int end_serial, const float *crd, float *g_radial) {
const float Rs = 0.5, Eta = 0.5;
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= start_serial && i < end_serial) {
float rij;
float g_radial_lin = 0.;
for (int j = start_serial; j < end_serial; j = j + 1) {
if (j != i) {
// rij = sqrtf((crd[3*i+0] - crd[j]) * (crd[i] - crd[j]));
rij = sqrtf(normfloat(crd, crd, i, j));
g_radial_lin = g_radial_lin + expf(-Eta * (rij - Rs) * (rij - Rs)) * fc(rij);
} else {
continue;
}
}
g_radial[i] = g_radial_lin;
}
}
__global__ void G_Angular(const int start_serial, const int end_serial, const float *crd, float *g_angular) {
const float Rs = 0.5, Thetas = 3.14, Eta = 0.5, Zeta = 2.0;
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= start_serial && i < end_serial) {
float rij, rik, rjk, theta_jik;
float g_angular_lin = 0.;
for (int j = start_serial; j < end_serial; j = j + 1) {
if (j != i) {
rij = sqrtf(normfloat(crd, crd, i, j));
for (int k = j + 1; k < end_serial; k = k + 1) {
if (k != i) {
rik = sqrtf(normfloat(crd, crd, i, k));
rjk = sqrtf(normfloat(crd, crd, j, k));
theta_jik =
acosf(fmaxf(fminf((rij * rij + rik * rik - rjk * rjk) / (2. * rij * rik), 0.999999), -0.999999));
g_angular_lin = g_angular_lin + powf(1. + cosf(theta_jik - Thetas), Zeta) *
expf(-Eta * powf(0.5 * (rij + rik) - Rs, 2.)) * fc(rij) * fc(rik);
} else {
continue;
}
}
} else {
continue;
}
}
g_angular[i] = powf(2., 1. - Zeta) * g_angular_lin;
}
}
void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular,
cudaStream_t stream) {
G_Radial<<<1, number, 0, stream>>>(start_serial, end_serial, crd_f, g_radial);
G_Angular<<<1, number, 0, stream>>>(start_serial, end_serial, crd_f, g_angular);
return;
}
void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular,
cudaStream_t stream);

View File

@ -14,13 +14,19 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/sponge/common/transfer_kernel.h"
#include "backend/kernel_compiler/gpu/sponge/common/atomcrdtocv_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
TransferCrd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TransferGpuKernel, float, int)
MS_REG_GPU_KERNEL_TWO(TransferCrd,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32),
AtomCrdToCVGpuKernel, float, int)
} // namespace kernel
} // namespace mindspore

View File

@ -14,10 +14,10 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh"
#include <cuda_runtime_api.h>
#include <map>
@ -31,19 +31,24 @@
namespace mindspore {
namespace kernel {
template <typename T, typename T1>
class TransferGpuKernel : public GpuKernel {
class AtomCrdToCVGpuKernel : public GpuKernel {
public:
TransferGpuKernel() : ele_crd(1) {}
~TransferGpuKernel() override = default;
AtomCrdToCVGpuKernel() : ele_crd(1) {}
~AtomCrdToCVGpuKernel() override = default;
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
start_serial = static_cast<int>(GetAttr<int64_t>(kernel_node, "start_serial"));
end_serial = static_cast<int>(GetAttr<int64_t>(kernel_node, "end_serial"));
number = static_cast<int>(GetAttr<int64_t>(kernel_node, "number"));
atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
auto shape_crd = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto shape_old_crd = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto shape_box = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < shape_crd.size(); i++) ele_crd *= shape_crd[i];
for (size_t i = 0; i < shape_old_crd.size(); i++) ele_old_crd *= shape_old_crd[i];
for (size_t i = 0; i < shape_box.size(); i++) ele_box *= shape_box[i];
InitSizeLists();
return true;
@ -53,27 +58,40 @@ class TransferGpuKernel : public GpuKernel {
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto crd = GetDeviceAddress<const T>(inputs, 0);
auto old_crd = GetDeviceAddress<const T>(inputs, 1);
auto box = GetDeviceAddress<T>(inputs, 2);
auto g_radial = GetDeviceAddress<T>(outputs, 0);
auto g_angular = GetDeviceAddress<T>(outputs, 1);
Transfer(start_serial, end_serial, number, crd, g_radial, g_angular, reinterpret_cast<cudaStream_t>(stream_ptr));
auto nowarp_crd = GetDeviceAddress<T>(outputs, 2);
auto box_map_times = GetDeviceAddress<T1>(outputs, 3);
AtomCrdToCV(atom_numbers, start_serial, end_serial, number, crd, old_crd, nowarp_crd, box_map_times, box, g_radial,
g_angular, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(ele_crd * sizeof(T));
input_size_list_.push_back(ele_old_crd * sizeof(T));
input_size_list_.push_back(ele_box * sizeof(T));
output_size_list_.push_back(number * sizeof(T));
output_size_list_.push_back(number * sizeof(T));
output_size_list_.push_back(3 * atom_numbers * sizeof(T));
output_size_list_.push_back(3 * atom_numbers * sizeof(T1));
}
private:
size_t ele_crd = 1;
size_t ele_old_crd = 1;
size_t ele_box = 1;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
@ -81,7 +99,8 @@ class TransferGpuKernel : public GpuKernel {
int end_serial;
int start_serial;
int number;
int atom_numbers;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_

View File

@ -2852,33 +2852,54 @@ class TransferCrd(PrimitiveWithInfer):
Inputs:
- **crd** (Tensor, float32) - [N, 3], the coordinate of each atom.
N is the number of atoms..
- **old_crd** (Tensor, float32) - [N, 3], the last coordinate of each atom.
N is the number of atoms.
- **box** (Tensor, float32) - [3,], the length of 3 dimensions of the simulation box.
Outputs:
- **output** (uint32)
- **radial** (Tensor, float32) - [number,], the array of radial transferred from coordinates.
- **angular** (Tensor, float32) - [number,], the array of angular transferred from coordinates.
- **nowarp_crd** (Tensor, float32) - [N, 3], the modified coordinate of each atom for
computing radial and angular.
- **box_map_times** (Tensor, int32) - [N, 3], the box map times for radial and angular.
Supported Platforms:
``GPU``
"""
@prim_attr_register
def __init__(self, start_serial, end_serial, number):
def __init__(self, start_serial, end_serial, number, atom_numbers):
validator.check_value_type('start_serial', start_serial, (int), self.name)
validator.check_value_type('end_serial', end_serial, (int), self.name)
validator.check_value_type('number', number, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
self.start_serial = start_serial
self.end_serial = end_serial
self.number = number
self.atom_numbers = atom_numbers
self.add_prim_attr('start_serial', self.start_serial)
self.add_prim_attr('end_serial', self.end_serial)
self.add_prim_attr('number', self.number)
self.add_prim_attr('atom_numbers', self.atom_numbers)
self.init_prim_io_names(
inputs=['crd'],
outputs=['radial', 'angular'])
inputs=['crd', 'old_crd', 'box'],
outputs=['radial', 'angular', 'nowarp_crd', 'box_map_times'])
def infer_shape(self, crd_shape):
def infer_shape(self, crd_shape, old_crd_shape, box_shape):
N = self.atom_numbers
validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", self.name)
validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[0]", self.name)
return [self.number,], [self.number,]
validator.check_int(crd_shape[0], N, Rel.EQ, "crd_shape[0]", self.name)
validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", self.name)
validator.check_int(len(old_crd_shape), 2, Rel.EQ, "old_crd_dim", self.name)
validator.check_int(old_crd_shape[0], N, Rel.EQ, "old_crd_shape[0]", self.name)
validator.check_int(old_crd_shape[1], 3, Rel.EQ, "old_crd_shape[1]", self.name)
validator.check_int(len(box_shape), 1, Rel.EQ, "box_dim", self.name)
validator.check_int(box_shape[0], 3, Rel.EQ, "box_shape[0]", self.name)
return [self.number,], [self.number,], [N, 3], [N, 3]
def infer_dtype(self, crd_dtype):
def infer_dtype(self, crd_dtype, old_crd_dtype, box_dtype):
validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
return mstype.float32, mstype.float32
validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('box', box_dtype, [mstype.float32], self.name)
return mstype.float32, mstype.float32, mstype.float32, mstype.int32

View File

@ -17,73 +17,57 @@ import argparse
import time
from src.simulation import Simulation
from src.mdnn import Mdnn, TransCrdToCV
import mindspore.context as context
from mindspore import Tensor
from mindspore import load_checkpoint
parser = argparse.ArgumentParser(description='Sponge Controller')
parser.add_argument('--i', type=str, default=None, help='input file')
parser.add_argument('--amber_parm', type=str, default=None, help='paramter file in AMBER type')
parser.add_argument('--c', type=str, default=None, help='initial coordinates file')
parser = argparse.ArgumentParser(description='SPONGE Controller')
parser.add_argument('--i', type=str, default=None, help='Input file')
parser.add_argument('--amber_parm', type=str, default=None, help='Paramter file in AMBER type')
parser.add_argument('--c', type=str, default=None, help='Initial coordinates file')
parser.add_argument('--r', type=str, default="restrt", help='')
parser.add_argument('--x', type=str, default="mdcrd", help='')
parser.add_argument('--o', type=str, default="mdout", help="")
parser.add_argument('--o', type=str, default="mdout", help='Output file')
parser.add_argument('--box', type=str, default="mdbox", help='')
parser.add_argument('--device_id', type=int, default=0, help='')
parser.add_argument('--device_id', type=int, default=0, help='GPU device id')
parser.add_argument('--u', type=bool, default=False, help='If use mdnn to update the atom charge')
parser.add_argument('--checkpoint', type=str, default="", help='Checkpoint file')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False)
if __name__ == "__main__":
simulation = Simulation(args_opt)
if args_opt.u and args_opt.checkpoint:
net = Mdnn()
load_checkpoint(args_opt.checkpoint, net=net)
transcrd = TransCrdToCV(simulation)
start = time.time()
compiler_time = 0
save_path = args_opt.o
file = open(save_path, 'w')
simulation.Main_Initial()
for steps in range(simulation.md_info.step_limit):
print_step = steps % simulation.ntwx
if steps == simulation.md_info.step_limit - 1:
print_step = 0
temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene, _ = simulation(Tensor(steps), Tensor(print_step))
if steps == 0:
compiler_time = time.time()
if steps % simulation.ntwx == 0 or steps == simulation.md_info.step_limit - 1:
if steps == 0:
print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_")
file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n")
simulation.Main_Print(steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene,
sigma_of_dihedral_ene, nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene)
temperature = temperature.asnumpy()
total_potential_energy = total_potential_energy.asnumpy()
print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)),
end=" ")
if simulation.bond.bond_numbers > 0:
sigma_of_bond_ene = sigma_of_bond_ene.asnumpy()
print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ")
if simulation.angle.angle_numbers > 0:
sigma_of_angle_ene = sigma_of_angle_ene.asnumpy()
print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ")
if simulation.dihedral.dihedral_numbers > 0:
sigma_of_dihedral_ene = sigma_of_dihedral_ene.asnumpy()
print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ")
if simulation.nb14.nb14_numbers > 0:
nb14_lj_energy_sum = nb14_lj_energy_sum.asnumpy()
nb14_cf_energy_sum = nb14_cf_energy_sum.asnumpy()
print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ")
LJ_energy_sum = LJ_energy_sum.asnumpy()
ee_ene = ee_ene.asnumpy()
print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ")
print("{:>12.3f}".format(float(ee_ene)))
if file is not None:
file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}"
" {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy),
float(sigma_of_bond_ene), float(sigma_of_angle_ene),
float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum),
float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene)))
if args_opt.u and args_opt.checkpoint and steps % (4 * simulation.ntwx) == 0:
print("Update charge!")
inputs = transcrd(Tensor(simulation.crd), Tensor(simulation.last_crd))
t_charge = net(inputs)
simulation.charge = transcrd.updatecharge(t_charge)
end = time.time()
file.close()
print("Main time(s):", end - start)
print("Main time(s) without compiler:", end - compiler_time)
simulation.Main_Destroy()

View File

@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mdnn class"""
import numpy as np
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype
class Mdnn(nn.Cell):
"""Mdnn"""
def __init__(self, dim=258, dr=0.5):
super(Mdnn, self).__init__()
self.dim = dim
self.dr = dr # dropout_ratio
self.fc1 = nn.Dense(dim, 512)
self.fc2 = nn.Dense(512, 512)
self.fc3 = nn.Dense(512, 512)
self.fc4 = nn.Dense(512, 129)
self.tanh = nn.Tanh()
def construct(self, x):
"""construct"""
x = self.tanh(self.fc1(x))
x = self.tanh(self.fc2(x))
x = self.tanh(self.fc3(x))
x = self.fc4(x)
return x
class TransCrdToCV(nn.Cell):
"""TransCrdToCV"""
def __init__(self, simulation):
super(TransCrdToCV, self).__init__()
self.atom_numbers = simulation.atom_numbers
self.transfercrd = P.TransferCrd(0, 129, 129, self.atom_numbers)
self.box = Tensor(simulation.box_length)
self.radial = Parameter(Tensor(np.zeros([129,]), mstype.float32))
self.angular = Parameter(Tensor(np.zeros([129,]), mstype.float32))
self.output = Parameter(Tensor(np.zeros([1, 258]), mstype.float32))
self.charge = simulation.charge
def updatecharge(self, t_charge):
"""update charge in simulation"""
self.charge[:129] = t_charge[0] * 18.2223
return self.charge
def construct(self, crd, last_crd):
"""construct"""
self.radial, self.angular, _, _ = self.transfercrd(crd, last_crd, self.box)
self.output = P.Concat()((self.radial, self.angular))
self.output = P.ExpandDims()(self.output, 0)
return self.output

View File

@ -34,6 +34,7 @@ from src.particle_mesh_ewald import Particle_Mesh_Ewald
class controller:
'''controller'''
def __init__(self, args_opt):
self.input_file = args_opt.i
self.initial_coordinates_file = args_opt.c
@ -67,6 +68,7 @@ class controller:
class Simulation(nn.Cell):
'''simulation'''
def __init__(self, args_opt):
super(Simulation, self).__init__()
self.control = controller(args_opt)
@ -119,6 +121,7 @@ class Simulation(nn.Cell):
self.exp_gamma = self.liujian_info.exp_gamma
self.init_Tensor()
self.op_define()
self.update = False
def init_Tensor(self):
'''init tensor'''
@ -129,9 +132,12 @@ class Simulation(nn.Cell):
self.uint_dr_to_dr_cof = Parameter(
Tensor(np.asarray(self.md_info.uint_dr_to_dr_cof, np.float32), mstype.float32), requires_grad=False)
self.box_length = Tensor(self.md_info.box_length, mstype.float32)
self.charge = Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32)
self.charge = Parameter(Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32),
requires_grad=False)
self.old_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32),
requires_grad=False)
self.last_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32),
requires_grad=False)
self.uint_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.uint32), mstype.uint32),
requires_grad=False)
self.mass_inverse = Tensor(self.md_info.h_mass_inverse, mstype.float32)
@ -341,8 +347,65 @@ class Simulation(nn.Cell):
acc = F.depend(self.acc, crd)
return vel, crd, acc
def Main_Print(self, *args):
"""compute the temperature"""
steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene = list(args)
if steps == 0:
print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_")
temperature = temperature.asnumpy()
total_potential_energy = total_potential_energy.asnumpy()
print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)),
end=" ")
if self.bond.bond_numbers > 0:
sigma_of_bond_ene = sigma_of_bond_ene.asnumpy()
print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ")
if self.angle.angle_numbers > 0:
sigma_of_angle_ene = sigma_of_angle_ene.asnumpy()
print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ")
if self.dihedral.dihedral_numbers > 0:
sigma_of_dihedral_ene = sigma_of_dihedral_ene.asnumpy()
print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ")
if self.nb14.nb14_numbers > 0:
nb14_lj_energy_sum = nb14_lj_energy_sum.asnumpy()
nb14_cf_energy_sum = nb14_cf_energy_sum.asnumpy()
print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ")
LJ_energy_sum = LJ_energy_sum.asnumpy()
ee_ene = ee_ene.asnumpy()
print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ")
print("{:>12.3f}".format(float(ee_ene)))
if self.file is not None:
self.file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}"
" {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy),
float(sigma_of_bond_ene), float(sigma_of_angle_ene),
float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum),
float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene)))
if self.datfile is not None:
self.datfile.write(self.crd.asnumpy())
def Main_Initial(self):
"""main initial"""
if self.control.mdout:
self.file = open(self.control.mdout, 'w')
self.file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n")
if self.control.mdcrd:
self.datfile = open(self.control.mdcrd, 'wb')
def Main_Destroy(self):
"""main destroy"""
if self.file is not None:
self.file.close()
print("Save .out file successfully!")
if self.datfile is not None:
self.datfile.close()
print("Save .dat file successfully!")
def construct(self, step, print_step):
'''construct'''
self.last_crd = self.crd
if step == 0:
res = self.neighbor_list_update_init(self.atom_numbers_in_grid_bucket, self.bucket, self.crd,
self.box_length, self.grid_N, self.grid_length_inverse,

View File

@ -0,0 +1,111 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train"""
import argparse
import numpy as np
from src.mdnn import Mdnn
from mindspore import nn, Model, context
from mindspore import dataset as ds
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.callback import Callback
import mindspore.common.initializer as weight_init
parser = argparse.ArgumentParser(description='Mdnn Controller')
parser.add_argument('--i', type=str, default=None, help='Input radial and angular dat file')
parser.add_argument('--charge', type=str, default=None, help='Input charge dat file')
parser.add_argument('--device_id', type=int, default=0, help='GPU device id')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False)
class StepLossAccInfo(Callback):
"""custom callback function"""
def __init__(self, models, eval_dataset, steploss):
"""init model"""
self.model = models
self.eval_dataset = eval_dataset
self.steps_loss = steploss
def step_end(self, run_context):
"""step end"""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
cur_step = (cur_epoch - 1) * 1875 + cb_params.cur_step_num
self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
self.steps_loss["step"].append(str(cur_step))
def get_data(inputdata, outputdata):
"""get data function"""
for _, data in enumerate(zip(inputdata, outputdata)):
yield data
def create_dataset(inputdata, outputdata, batchsize=32, repeat_size=1):
"""create dataset function"""
input_data = ds.GeneratorDataset(list(get_data(inputdata, outputdata)), column_names=['data', 'label'])
input_data = input_data.batch(batchsize)
input_data = input_data.repeat(repeat_size)
return input_data
def init_weight(nnet):
for _, cell in nnet.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
cell.weight.shape,
cell.weight.dtype))
if isinstance(cell, nn.Dense):
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
cell.weight.shape,
cell.weight.dtype))
if __name__ == '__main__':
# read input files
inputs = args_opt.i
outputs = args_opt.charge
radial_angular = np.fromfile(inputs, dtype=np.float32)
radial_angular = radial_angular.reshape((-1, 258)).astype(np.float32)
charge = np.fromfile(outputs, dtype=np.float32)
charge = charge.reshape((-1, 129)).astype(np.float32)
# define the model
net = Mdnn()
lr = 0.0001
decay_rate = 0.8
epoch_size = 1000
batch_size = 500
total_step = epoch_size * batch_size
step_per_epoch = 100
decay_epoch = epoch_size
lr_rate = nn.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)
net_loss = nn.loss.MSELoss(reduction='mean')
net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_rate)
model = Model(net, net_loss, net_opt)
ds_train = create_dataset(radial_angular, charge, batchsize=batch_size)
model_params = net.trainable_params()
net.set_train()
init_weight(net)
# config files
path = './params/'
config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="mdnn_best", directory=path, config=config_ck)
steps_loss = {"step": [], "loss_value": []}
step_loss_acc_info = StepLossAccInfo(model, ds_train, steps_loss)
# train the model
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(100)])