forked from mindspore-Ecosystem/mindspore
!15331 fix sponge case master 0417
From: @jiahongqian Reviewed-by: @wang_zi_dong,@ljl0711 Signed-off-by: @ljl0711
This commit is contained in:
commit
d35939603a
|
@ -0,0 +1,123 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
|
||||
|
||||
__device__ __host__ float fc(float Rij) {
|
||||
const float PI = 3.141592654;
|
||||
const float Rc = 1000.0;
|
||||
return 0.5 * cosf(PI / Rc * Rij) + 0.5;
|
||||
}
|
||||
|
||||
__global__ void Record_Box_Map_Times(int atom_numbers, const float *crd, const float *old_crd, float *box,
|
||||
int *box_map_times) {
|
||||
float half_box[3] = {0.5 * box[0], 0.5 * box[1], 0.5 * box[2]};
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i < atom_numbers) {
|
||||
if (crd[3 * i + 0] - old_crd[3 * i + 0] > half_box[0]) {
|
||||
box_map_times[3 * i + 0] = box_map_times[3 * i + 0] - 1;
|
||||
} else if (crd[3 * i + 0] - old_crd[3 * i + 0] < -half_box[0]) {
|
||||
box_map_times[3 * i + 0] = box_map_times[3 * i + 0] + 1;
|
||||
}
|
||||
if (crd[3 * i + 1] - old_crd[3 * i + 1] > half_box[1]) {
|
||||
box_map_times[3 * i + 1] = box_map_times[3 * i + 1] - 1;
|
||||
} else if (crd[3 * i + 1] - old_crd[3 * i + 1] < -half_box[1]) {
|
||||
box_map_times[3 * i + 1] = box_map_times[3 * i + 1] + 1;
|
||||
}
|
||||
if (crd[3 * i + 2] - old_crd[3 * i + 2] > half_box[2]) {
|
||||
box_map_times[3 * i + 2] = box_map_times[3 * i + 2] - 1;
|
||||
} else if (crd[3 * i + 2] - old_crd[3 * i + 2] < -half_box[2]) {
|
||||
box_map_times[3 * i + 2] = box_map_times[3 * i + 2] + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gen_nowarp_crd(int atom_numbers, const float *crd, float *box, int *box_map_times, float *nowarp_crd) {
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i < atom_numbers) {
|
||||
nowarp_crd[3 * i + 0] = static_cast<float>(box_map_times[3 * i + 0]) * box[0] + crd[3 * i + 0];
|
||||
nowarp_crd[3 * i + 1] = static_cast<float>(box_map_times[3 * i + 1]) * box[1] + crd[3 * i + 1];
|
||||
nowarp_crd[3 * i + 2] = static_cast<float>(box_map_times[3 * i + 2]) * box[2] + crd[3 * i + 2];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void G_Radial(const int start_serial, const int end_serial, const float *crd, float *g_radial) {
|
||||
const float Rs = 0.5, Eta = 0.5;
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i >= start_serial && i < end_serial) {
|
||||
float rij;
|
||||
float g_radial_lin = 0.;
|
||||
for (int j = start_serial; j < end_serial; j = j + 1) {
|
||||
if (j != i) {
|
||||
// rij = sqrtf((crd[3*i+0] - crd[j]) * (crd[i] - crd[j]));
|
||||
rij = sqrtf(normfloat(crd, crd, i, j));
|
||||
g_radial_lin = g_radial_lin + expf(-Eta * (rij - Rs) * (rij - Rs)) * fc(rij);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
g_radial[i] = g_radial_lin;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void G_Angular(const int start_serial, const int end_serial, const float *crd, float *g_angular) {
|
||||
const float Rs = 0.5, Thetas = 3.14, Eta = 0.5, Zeta = 2.0;
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i >= start_serial && i < end_serial) {
|
||||
float rij, rik, rjk, theta_jik;
|
||||
float g_angular_lin = 0.;
|
||||
for (int j = start_serial; j < end_serial; j = j + 1) {
|
||||
if (j != i) {
|
||||
rij = sqrtf(normfloat(crd, crd, i, j));
|
||||
for (int k = j + 1; k < end_serial; k = k + 1) {
|
||||
if (k != i) {
|
||||
rik = sqrtf(normfloat(crd, crd, i, k));
|
||||
rjk = sqrtf(normfloat(crd, crd, j, k));
|
||||
theta_jik =
|
||||
acosf(fmaxf(fminf((rij * rij + rik * rik - rjk * rjk) / (2. * rij * rik), 0.999999), -0.999999));
|
||||
g_angular_lin = g_angular_lin + powf(1. + cosf(theta_jik - Thetas), Zeta) *
|
||||
expf(-Eta * powf(0.5 * (rij + rik) - Rs, 2.)) * fc(rij) * fc(rik);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
g_angular[i] = powf(2., 1. - Zeta) * g_angular_lin;
|
||||
}
|
||||
}
|
||||
|
||||
void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f,
|
||||
const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial,
|
||||
float *g_angular, cudaStream_t stream) {
|
||||
Reset_List<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(3 * atom_numbers, box_map_times,
|
||||
0);
|
||||
Record_Box_Map_Times<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(
|
||||
atom_numbers, crd_f, old_crd, box, box_map_times);
|
||||
gen_nowarp_crd<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(atom_numbers, crd_f, box,
|
||||
box_map_times, nowarp_crd);
|
||||
G_Radial<<<1, number, 0, stream>>>(start_serial, end_serial, nowarp_crd, g_radial);
|
||||
G_Angular<<<1, number, 0, stream>>>(start_serial, end_serial, nowarp_crd, g_angular);
|
||||
return;
|
||||
}
|
||||
|
||||
void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f,
|
||||
const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial,
|
||||
float *g_angular, cudaStream_t stream);
|
|
@ -14,12 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular,
|
||||
cudaStream_t stream);
|
||||
void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f,
|
||||
const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial,
|
||||
float *g_angular, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_
|
|
@ -1,83 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
|
||||
|
||||
__device__ __host__ float fc(float Rij) {
|
||||
const float PI = 3.141592654;
|
||||
const float Rc = 1000.0;
|
||||
return 0.5 * cosf(PI / Rc * Rij) + 0.5;
|
||||
}
|
||||
|
||||
__global__ void G_Radial(const int start_serial, const int end_serial, const float *crd, float *g_radial) {
|
||||
const float Rs = 0.5, Eta = 0.5;
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i >= start_serial && i < end_serial) {
|
||||
float rij;
|
||||
float g_radial_lin = 0.;
|
||||
for (int j = start_serial; j < end_serial; j = j + 1) {
|
||||
if (j != i) {
|
||||
// rij = sqrtf((crd[3*i+0] - crd[j]) * (crd[i] - crd[j]));
|
||||
rij = sqrtf(normfloat(crd, crd, i, j));
|
||||
g_radial_lin = g_radial_lin + expf(-Eta * (rij - Rs) * (rij - Rs)) * fc(rij);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
g_radial[i] = g_radial_lin;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void G_Angular(const int start_serial, const int end_serial, const float *crd, float *g_angular) {
|
||||
const float Rs = 0.5, Thetas = 3.14, Eta = 0.5, Zeta = 2.0;
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i >= start_serial && i < end_serial) {
|
||||
float rij, rik, rjk, theta_jik;
|
||||
float g_angular_lin = 0.;
|
||||
for (int j = start_serial; j < end_serial; j = j + 1) {
|
||||
if (j != i) {
|
||||
rij = sqrtf(normfloat(crd, crd, i, j));
|
||||
for (int k = j + 1; k < end_serial; k = k + 1) {
|
||||
if (k != i) {
|
||||
rik = sqrtf(normfloat(crd, crd, i, k));
|
||||
rjk = sqrtf(normfloat(crd, crd, j, k));
|
||||
theta_jik =
|
||||
acosf(fmaxf(fminf((rij * rij + rik * rik - rjk * rjk) / (2. * rij * rik), 0.999999), -0.999999));
|
||||
g_angular_lin = g_angular_lin + powf(1. + cosf(theta_jik - Thetas), Zeta) *
|
||||
expf(-Eta * powf(0.5 * (rij + rik) - Rs, 2.)) * fc(rij) * fc(rik);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
g_angular[i] = powf(2., 1. - Zeta) * g_angular_lin;
|
||||
}
|
||||
}
|
||||
|
||||
void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular,
|
||||
cudaStream_t stream) {
|
||||
G_Radial<<<1, number, 0, stream>>>(start_serial, end_serial, crd_f, g_radial);
|
||||
G_Angular<<<1, number, 0, stream>>>(start_serial, end_serial, crd_f, g_angular);
|
||||
return;
|
||||
}
|
||||
|
||||
void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular,
|
||||
cudaStream_t stream);
|
|
@ -14,13 +14,19 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/sponge/common/transfer_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/sponge/common/atomcrdtocv_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
TransferCrd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
TransferGpuKernel, float, int)
|
||||
MS_REG_GPU_KERNEL_TWO(TransferCrd,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
AtomCrdToCVGpuKernel, float, int)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -14,10 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh"
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <map>
|
||||
|
@ -31,19 +31,24 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename T1>
|
||||
class TransferGpuKernel : public GpuKernel {
|
||||
class AtomCrdToCVGpuKernel : public GpuKernel {
|
||||
public:
|
||||
TransferGpuKernel() : ele_crd(1) {}
|
||||
~TransferGpuKernel() override = default;
|
||||
AtomCrdToCVGpuKernel() : ele_crd(1) {}
|
||||
~AtomCrdToCVGpuKernel() override = default;
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
start_serial = static_cast<int>(GetAttr<int64_t>(kernel_node, "start_serial"));
|
||||
end_serial = static_cast<int>(GetAttr<int64_t>(kernel_node, "end_serial"));
|
||||
number = static_cast<int>(GetAttr<int64_t>(kernel_node, "number"));
|
||||
atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
|
||||
auto shape_crd = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto shape_old_crd = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto shape_box = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
|
||||
for (size_t i = 0; i < shape_crd.size(); i++) ele_crd *= shape_crd[i];
|
||||
for (size_t i = 0; i < shape_old_crd.size(); i++) ele_old_crd *= shape_old_crd[i];
|
||||
for (size_t i = 0; i < shape_box.size(); i++) ele_box *= shape_box[i];
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
|
@ -53,27 +58,40 @@ class TransferGpuKernel : public GpuKernel {
|
|||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
auto crd = GetDeviceAddress<const T>(inputs, 0);
|
||||
auto old_crd = GetDeviceAddress<const T>(inputs, 1);
|
||||
auto box = GetDeviceAddress<T>(inputs, 2);
|
||||
|
||||
auto g_radial = GetDeviceAddress<T>(outputs, 0);
|
||||
auto g_angular = GetDeviceAddress<T>(outputs, 1);
|
||||
|
||||
Transfer(start_serial, end_serial, number, crd, g_radial, g_angular, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
auto nowarp_crd = GetDeviceAddress<T>(outputs, 2);
|
||||
auto box_map_times = GetDeviceAddress<T1>(outputs, 3);
|
||||
|
||||
AtomCrdToCV(atom_numbers, start_serial, end_serial, number, crd, old_crd, nowarp_crd, box_map_times, box, g_radial,
|
||||
g_angular, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(ele_crd * sizeof(T));
|
||||
input_size_list_.push_back(ele_old_crd * sizeof(T));
|
||||
input_size_list_.push_back(ele_box * sizeof(T));
|
||||
|
||||
output_size_list_.push_back(number * sizeof(T));
|
||||
output_size_list_.push_back(number * sizeof(T));
|
||||
|
||||
output_size_list_.push_back(3 * atom_numbers * sizeof(T));
|
||||
output_size_list_.push_back(3 * atom_numbers * sizeof(T1));
|
||||
}
|
||||
|
||||
private:
|
||||
size_t ele_crd = 1;
|
||||
size_t ele_old_crd = 1;
|
||||
size_t ele_box = 1;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
|
@ -81,7 +99,8 @@ class TransferGpuKernel : public GpuKernel {
|
|||
int end_serial;
|
||||
int start_serial;
|
||||
int number;
|
||||
int atom_numbers;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_
|
|
@ -2852,33 +2852,54 @@ class TransferCrd(PrimitiveWithInfer):
|
|||
|
||||
Inputs:
|
||||
- **crd** (Tensor, float32) - [N, 3], the coordinate of each atom.
|
||||
N is the number of atoms..
|
||||
- **old_crd** (Tensor, float32) - [N, 3], the last coordinate of each atom.
|
||||
N is the number of atoms.
|
||||
|
||||
- **box** (Tensor, float32) - [3,], the length of 3 dimensions of the simulation box.
|
||||
|
||||
Outputs:
|
||||
- **output** (uint32)
|
||||
- **radial** (Tensor, float32) - [number,], the array of radial transferred from coordinates.
|
||||
- **angular** (Tensor, float32) - [number,], the array of angular transferred from coordinates.
|
||||
- **nowarp_crd** (Tensor, float32) - [N, 3], the modified coordinate of each atom for
|
||||
computing radial and angular.
|
||||
- **box_map_times** (Tensor, int32) - [N, 3], the box map times for radial and angular.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, start_serial, end_serial, number):
|
||||
def __init__(self, start_serial, end_serial, number, atom_numbers):
|
||||
validator.check_value_type('start_serial', start_serial, (int), self.name)
|
||||
validator.check_value_type('end_serial', end_serial, (int), self.name)
|
||||
validator.check_value_type('number', number, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
self.start_serial = start_serial
|
||||
self.end_serial = end_serial
|
||||
self.number = number
|
||||
self.atom_numbers = atom_numbers
|
||||
self.add_prim_attr('start_serial', self.start_serial)
|
||||
self.add_prim_attr('end_serial', self.end_serial)
|
||||
self.add_prim_attr('number', self.number)
|
||||
self.add_prim_attr('atom_numbers', self.atom_numbers)
|
||||
self.init_prim_io_names(
|
||||
inputs=['crd'],
|
||||
outputs=['radial', 'angular'])
|
||||
inputs=['crd', 'old_crd', 'box'],
|
||||
outputs=['radial', 'angular', 'nowarp_crd', 'box_map_times'])
|
||||
|
||||
def infer_shape(self, crd_shape):
|
||||
def infer_shape(self, crd_shape, old_crd_shape, box_shape):
|
||||
N = self.atom_numbers
|
||||
validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", self.name)
|
||||
validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[0]", self.name)
|
||||
return [self.number,], [self.number,]
|
||||
validator.check_int(crd_shape[0], N, Rel.EQ, "crd_shape[0]", self.name)
|
||||
validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", self.name)
|
||||
validator.check_int(len(old_crd_shape), 2, Rel.EQ, "old_crd_dim", self.name)
|
||||
validator.check_int(old_crd_shape[0], N, Rel.EQ, "old_crd_shape[0]", self.name)
|
||||
validator.check_int(old_crd_shape[1], 3, Rel.EQ, "old_crd_shape[1]", self.name)
|
||||
validator.check_int(len(box_shape), 1, Rel.EQ, "box_dim", self.name)
|
||||
validator.check_int(box_shape[0], 3, Rel.EQ, "box_shape[0]", self.name)
|
||||
return [self.number,], [self.number,], [N, 3], [N, 3]
|
||||
|
||||
def infer_dtype(self, crd_dtype):
|
||||
def infer_dtype(self, crd_dtype, old_crd_dtype, box_dtype):
|
||||
validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
|
||||
return mstype.float32, mstype.float32
|
||||
validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name)
|
||||
validator.check_tensor_dtype_valid('box', box_dtype, [mstype.float32], self.name)
|
||||
return mstype.float32, mstype.float32, mstype.float32, mstype.int32
|
||||
|
|
|
@ -17,73 +17,57 @@ import argparse
|
|||
import time
|
||||
|
||||
from src.simulation import Simulation
|
||||
|
||||
from src.mdnn import Mdnn, TransCrdToCV
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore import load_checkpoint
|
||||
|
||||
parser = argparse.ArgumentParser(description='Sponge Controller')
|
||||
parser.add_argument('--i', type=str, default=None, help='input file')
|
||||
parser.add_argument('--amber_parm', type=str, default=None, help='paramter file in AMBER type')
|
||||
parser.add_argument('--c', type=str, default=None, help='initial coordinates file')
|
||||
parser = argparse.ArgumentParser(description='SPONGE Controller')
|
||||
parser.add_argument('--i', type=str, default=None, help='Input file')
|
||||
parser.add_argument('--amber_parm', type=str, default=None, help='Paramter file in AMBER type')
|
||||
parser.add_argument('--c', type=str, default=None, help='Initial coordinates file')
|
||||
parser.add_argument('--r', type=str, default="restrt", help='')
|
||||
parser.add_argument('--x', type=str, default="mdcrd", help='')
|
||||
parser.add_argument('--o', type=str, default="mdout", help="")
|
||||
parser.add_argument('--o', type=str, default="mdout", help='Output file')
|
||||
parser.add_argument('--box', type=str, default="mdbox", help='')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='GPU device id')
|
||||
parser.add_argument('--u', type=bool, default=False, help='If use mdnn to update the atom charge')
|
||||
parser.add_argument('--checkpoint', type=str, default="", help='Checkpoint file')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
simulation = Simulation(args_opt)
|
||||
if args_opt.u and args_opt.checkpoint:
|
||||
net = Mdnn()
|
||||
load_checkpoint(args_opt.checkpoint, net=net)
|
||||
transcrd = TransCrdToCV(simulation)
|
||||
|
||||
start = time.time()
|
||||
compiler_time = 0
|
||||
save_path = args_opt.o
|
||||
file = open(save_path, 'w')
|
||||
simulation.Main_Initial()
|
||||
for steps in range(simulation.md_info.step_limit):
|
||||
print_step = steps % simulation.ntwx
|
||||
if steps == simulation.md_info.step_limit - 1:
|
||||
print_step = 0
|
||||
temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
|
||||
nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene, _ = simulation(Tensor(steps), Tensor(print_step))
|
||||
|
||||
if steps == 0:
|
||||
compiler_time = time.time()
|
||||
if steps % simulation.ntwx == 0 or steps == simulation.md_info.step_limit - 1:
|
||||
if steps == 0:
|
||||
print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
|
||||
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_")
|
||||
file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
|
||||
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n")
|
||||
simulation.Main_Print(steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene,
|
||||
sigma_of_dihedral_ene, nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene)
|
||||
|
||||
temperature = temperature.asnumpy()
|
||||
total_potential_energy = total_potential_energy.asnumpy()
|
||||
print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)),
|
||||
end=" ")
|
||||
if simulation.bond.bond_numbers > 0:
|
||||
sigma_of_bond_ene = sigma_of_bond_ene.asnumpy()
|
||||
print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ")
|
||||
if simulation.angle.angle_numbers > 0:
|
||||
sigma_of_angle_ene = sigma_of_angle_ene.asnumpy()
|
||||
print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ")
|
||||
if simulation.dihedral.dihedral_numbers > 0:
|
||||
sigma_of_dihedral_ene = sigma_of_dihedral_ene.asnumpy()
|
||||
print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ")
|
||||
if simulation.nb14.nb14_numbers > 0:
|
||||
nb14_lj_energy_sum = nb14_lj_energy_sum.asnumpy()
|
||||
nb14_cf_energy_sum = nb14_cf_energy_sum.asnumpy()
|
||||
print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ")
|
||||
LJ_energy_sum = LJ_energy_sum.asnumpy()
|
||||
ee_ene = ee_ene.asnumpy()
|
||||
print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ")
|
||||
print("{:>12.3f}".format(float(ee_ene)))
|
||||
if file is not None:
|
||||
file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}"
|
||||
" {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy),
|
||||
float(sigma_of_bond_ene), float(sigma_of_angle_ene),
|
||||
float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum),
|
||||
float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene)))
|
||||
if args_opt.u and args_opt.checkpoint and steps % (4 * simulation.ntwx) == 0:
|
||||
print("Update charge!")
|
||||
inputs = transcrd(Tensor(simulation.crd), Tensor(simulation.last_crd))
|
||||
t_charge = net(inputs)
|
||||
simulation.charge = transcrd.updatecharge(t_charge)
|
||||
|
||||
end = time.time()
|
||||
file.close()
|
||||
print("Main time(s):", end - start)
|
||||
print("Main time(s) without compiler:", end - compiler_time)
|
||||
simulation.Main_Destroy()
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""mdnn class"""
|
||||
|
||||
import numpy as np
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.parameter import Parameter
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
class Mdnn(nn.Cell):
|
||||
"""Mdnn"""
|
||||
|
||||
def __init__(self, dim=258, dr=0.5):
|
||||
super(Mdnn, self).__init__()
|
||||
self.dim = dim
|
||||
self.dr = dr # dropout_ratio
|
||||
self.fc1 = nn.Dense(dim, 512)
|
||||
self.fc2 = nn.Dense(512, 512)
|
||||
self.fc3 = nn.Dense(512, 512)
|
||||
self.fc4 = nn.Dense(512, 129)
|
||||
self.tanh = nn.Tanh()
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
x = self.tanh(self.fc1(x))
|
||||
x = self.tanh(self.fc2(x))
|
||||
x = self.tanh(self.fc3(x))
|
||||
x = self.fc4(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransCrdToCV(nn.Cell):
|
||||
"""TransCrdToCV"""
|
||||
|
||||
def __init__(self, simulation):
|
||||
super(TransCrdToCV, self).__init__()
|
||||
self.atom_numbers = simulation.atom_numbers
|
||||
self.transfercrd = P.TransferCrd(0, 129, 129, self.atom_numbers)
|
||||
self.box = Tensor(simulation.box_length)
|
||||
self.radial = Parameter(Tensor(np.zeros([129,]), mstype.float32))
|
||||
self.angular = Parameter(Tensor(np.zeros([129,]), mstype.float32))
|
||||
self.output = Parameter(Tensor(np.zeros([1, 258]), mstype.float32))
|
||||
self.charge = simulation.charge
|
||||
|
||||
def updatecharge(self, t_charge):
|
||||
"""update charge in simulation"""
|
||||
self.charge[:129] = t_charge[0] * 18.2223
|
||||
return self.charge
|
||||
|
||||
def construct(self, crd, last_crd):
|
||||
"""construct"""
|
||||
self.radial, self.angular, _, _ = self.transfercrd(crd, last_crd, self.box)
|
||||
self.output = P.Concat()((self.radial, self.angular))
|
||||
self.output = P.ExpandDims()(self.output, 0)
|
||||
return self.output
|
|
@ -34,6 +34,7 @@ from src.particle_mesh_ewald import Particle_Mesh_Ewald
|
|||
|
||||
class controller:
|
||||
'''controller'''
|
||||
|
||||
def __init__(self, args_opt):
|
||||
self.input_file = args_opt.i
|
||||
self.initial_coordinates_file = args_opt.c
|
||||
|
@ -67,6 +68,7 @@ class controller:
|
|||
|
||||
class Simulation(nn.Cell):
|
||||
'''simulation'''
|
||||
|
||||
def __init__(self, args_opt):
|
||||
super(Simulation, self).__init__()
|
||||
self.control = controller(args_opt)
|
||||
|
@ -119,6 +121,7 @@ class Simulation(nn.Cell):
|
|||
self.exp_gamma = self.liujian_info.exp_gamma
|
||||
self.init_Tensor()
|
||||
self.op_define()
|
||||
self.update = False
|
||||
|
||||
def init_Tensor(self):
|
||||
'''init tensor'''
|
||||
|
@ -129,9 +132,12 @@ class Simulation(nn.Cell):
|
|||
self.uint_dr_to_dr_cof = Parameter(
|
||||
Tensor(np.asarray(self.md_info.uint_dr_to_dr_cof, np.float32), mstype.float32), requires_grad=False)
|
||||
self.box_length = Tensor(self.md_info.box_length, mstype.float32)
|
||||
self.charge = Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32)
|
||||
self.charge = Parameter(Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32),
|
||||
requires_grad=False)
|
||||
self.old_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32),
|
||||
requires_grad=False)
|
||||
self.last_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32),
|
||||
requires_grad=False)
|
||||
self.uint_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.uint32), mstype.uint32),
|
||||
requires_grad=False)
|
||||
self.mass_inverse = Tensor(self.md_info.h_mass_inverse, mstype.float32)
|
||||
|
@ -341,8 +347,65 @@ class Simulation(nn.Cell):
|
|||
acc = F.depend(self.acc, crd)
|
||||
return vel, crd, acc
|
||||
|
||||
def Main_Print(self, *args):
|
||||
"""compute the temperature"""
|
||||
steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
|
||||
nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene = list(args)
|
||||
if steps == 0:
|
||||
print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
|
||||
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_")
|
||||
|
||||
temperature = temperature.asnumpy()
|
||||
total_potential_energy = total_potential_energy.asnumpy()
|
||||
print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)),
|
||||
end=" ")
|
||||
if self.bond.bond_numbers > 0:
|
||||
sigma_of_bond_ene = sigma_of_bond_ene.asnumpy()
|
||||
print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ")
|
||||
if self.angle.angle_numbers > 0:
|
||||
sigma_of_angle_ene = sigma_of_angle_ene.asnumpy()
|
||||
print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ")
|
||||
if self.dihedral.dihedral_numbers > 0:
|
||||
sigma_of_dihedral_ene = sigma_of_dihedral_ene.asnumpy()
|
||||
print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ")
|
||||
if self.nb14.nb14_numbers > 0:
|
||||
nb14_lj_energy_sum = nb14_lj_energy_sum.asnumpy()
|
||||
nb14_cf_energy_sum = nb14_cf_energy_sum.asnumpy()
|
||||
print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ")
|
||||
LJ_energy_sum = LJ_energy_sum.asnumpy()
|
||||
ee_ene = ee_ene.asnumpy()
|
||||
print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ")
|
||||
print("{:>12.3f}".format(float(ee_ene)))
|
||||
if self.file is not None:
|
||||
self.file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}"
|
||||
" {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy),
|
||||
float(sigma_of_bond_ene), float(sigma_of_angle_ene),
|
||||
float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum),
|
||||
float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene)))
|
||||
if self.datfile is not None:
|
||||
self.datfile.write(self.crd.asnumpy())
|
||||
|
||||
def Main_Initial(self):
|
||||
"""main initial"""
|
||||
if self.control.mdout:
|
||||
self.file = open(self.control.mdout, 'w')
|
||||
self.file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ "
|
||||
"_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n")
|
||||
if self.control.mdcrd:
|
||||
self.datfile = open(self.control.mdcrd, 'wb')
|
||||
|
||||
def Main_Destroy(self):
|
||||
"""main destroy"""
|
||||
if self.file is not None:
|
||||
self.file.close()
|
||||
print("Save .out file successfully!")
|
||||
if self.datfile is not None:
|
||||
self.datfile.close()
|
||||
print("Save .dat file successfully!")
|
||||
|
||||
def construct(self, step, print_step):
|
||||
'''construct'''
|
||||
self.last_crd = self.crd
|
||||
if step == 0:
|
||||
res = self.neighbor_list_update_init(self.atom_numbers_in_grid_bucket, self.bucket, self.crd,
|
||||
self.box_length, self.grid_N, self.grid_length_inverse,
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
from src.mdnn import Mdnn
|
||||
from mindspore import nn, Model, context
|
||||
from mindspore import dataset as ds
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||
from mindspore.train.callback import Callback
|
||||
import mindspore.common.initializer as weight_init
|
||||
|
||||
parser = argparse.ArgumentParser(description='Mdnn Controller')
|
||||
parser.add_argument('--i', type=str, default=None, help='Input radial and angular dat file')
|
||||
parser.add_argument('--charge', type=str, default=None, help='Input charge dat file')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='GPU device id')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False)
|
||||
|
||||
|
||||
class StepLossAccInfo(Callback):
|
||||
"""custom callback function"""
|
||||
|
||||
def __init__(self, models, eval_dataset, steploss):
|
||||
"""init model"""
|
||||
self.model = models
|
||||
self.eval_dataset = eval_dataset
|
||||
self.steps_loss = steploss
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""step end"""
|
||||
cb_params = run_context.original_args()
|
||||
cur_epoch = cb_params.cur_epoch_num
|
||||
cur_step = (cur_epoch - 1) * 1875 + cb_params.cur_step_num
|
||||
self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
|
||||
self.steps_loss["step"].append(str(cur_step))
|
||||
|
||||
|
||||
def get_data(inputdata, outputdata):
|
||||
"""get data function"""
|
||||
for _, data in enumerate(zip(inputdata, outputdata)):
|
||||
yield data
|
||||
|
||||
|
||||
def create_dataset(inputdata, outputdata, batchsize=32, repeat_size=1):
|
||||
"""create dataset function"""
|
||||
|
||||
input_data = ds.GeneratorDataset(list(get_data(inputdata, outputdata)), column_names=['data', 'label'])
|
||||
input_data = input_data.batch(batchsize)
|
||||
input_data = input_data.repeat(repeat_size)
|
||||
return input_data
|
||||
|
||||
|
||||
def init_weight(nnet):
|
||||
for _, cell in nnet.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
|
||||
cell.weight.shape,
|
||||
cell.weight.dtype))
|
||||
if isinstance(cell, nn.Dense):
|
||||
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
|
||||
cell.weight.shape,
|
||||
cell.weight.dtype))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# read input files
|
||||
inputs = args_opt.i
|
||||
outputs = args_opt.charge
|
||||
radial_angular = np.fromfile(inputs, dtype=np.float32)
|
||||
radial_angular = radial_angular.reshape((-1, 258)).astype(np.float32)
|
||||
charge = np.fromfile(outputs, dtype=np.float32)
|
||||
charge = charge.reshape((-1, 129)).astype(np.float32)
|
||||
# define the model
|
||||
net = Mdnn()
|
||||
lr = 0.0001
|
||||
decay_rate = 0.8
|
||||
epoch_size = 1000
|
||||
batch_size = 500
|
||||
total_step = epoch_size * batch_size
|
||||
step_per_epoch = 100
|
||||
decay_epoch = epoch_size
|
||||
lr_rate = nn.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)
|
||||
net_loss = nn.loss.MSELoss(reduction='mean')
|
||||
net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_rate)
|
||||
model = Model(net, net_loss, net_opt)
|
||||
ds_train = create_dataset(radial_angular, charge, batchsize=batch_size)
|
||||
model_params = net.trainable_params()
|
||||
net.set_train()
|
||||
init_weight(net)
|
||||
# config files
|
||||
path = './params/'
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="mdnn_best", directory=path, config=config_ck)
|
||||
steps_loss = {"step": [], "loss_value": []}
|
||||
step_loss_acc_info = StepLossAccInfo(model, ds_train, steps_loss)
|
||||
# train the model
|
||||
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(100)])
|
Loading…
Reference in New Issue