diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cu new file mode 100644 index 00000000000..31d0867fc84 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cu @@ -0,0 +1,123 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" + +__device__ __host__ float fc(float Rij) { + const float PI = 3.141592654; + const float Rc = 1000.0; + return 0.5 * cosf(PI / Rc * Rij) + 0.5; +} + +__global__ void Record_Box_Map_Times(int atom_numbers, const float *crd, const float *old_crd, float *box, + int *box_map_times) { + float half_box[3] = {0.5 * box[0], 0.5 * box[1], 0.5 * box[2]}; + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < atom_numbers) { + if (crd[3 * i + 0] - old_crd[3 * i + 0] > half_box[0]) { + box_map_times[3 * i + 0] = box_map_times[3 * i + 0] - 1; + } else if (crd[3 * i + 0] - old_crd[3 * i + 0] < -half_box[0]) { + box_map_times[3 * i + 0] = box_map_times[3 * i + 0] + 1; + } + if (crd[3 * i + 1] - old_crd[3 * i + 1] > half_box[1]) { + box_map_times[3 * i + 1] = box_map_times[3 * i + 1] - 1; + } else if (crd[3 * i + 1] - old_crd[3 * i + 1] < -half_box[1]) { + box_map_times[3 * i + 1] = box_map_times[3 * i + 1] + 1; + } + if (crd[3 * i + 2] - old_crd[3 * i + 2] > half_box[2]) { + box_map_times[3 * i + 2] = box_map_times[3 * i + 2] - 1; + } else if (crd[3 * i + 2] - old_crd[3 * i + 2] < -half_box[2]) { + box_map_times[3 * i + 2] = box_map_times[3 * i + 2] + 1; + } + } +} + +__global__ void gen_nowarp_crd(int atom_numbers, const float *crd, float *box, int *box_map_times, float *nowarp_crd) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < atom_numbers) { + nowarp_crd[3 * i + 0] = static_cast(box_map_times[3 * i + 0]) * box[0] + crd[3 * i + 0]; + nowarp_crd[3 * i + 1] = static_cast(box_map_times[3 * i + 1]) * box[1] + crd[3 * i + 1]; + nowarp_crd[3 * i + 2] = static_cast(box_map_times[3 * i + 2]) * box[2] + crd[3 * i + 2]; + } +} + +__global__ void G_Radial(const int start_serial, const int end_serial, const float *crd, float *g_radial) { + const float Rs = 0.5, Eta = 0.5; + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i >= start_serial && i < end_serial) { + float rij; + float g_radial_lin = 0.; + for (int j = start_serial; j < end_serial; j = j + 1) { + if (j != i) { + // rij = sqrtf((crd[3*i+0] - crd[j]) * (crd[i] - crd[j])); + rij = sqrtf(normfloat(crd, crd, i, j)); + g_radial_lin = g_radial_lin + expf(-Eta * (rij - Rs) * (rij - Rs)) * fc(rij); + } else { + continue; + } + } + g_radial[i] = g_radial_lin; + } +} + +__global__ void G_Angular(const int start_serial, const int end_serial, const float *crd, float *g_angular) { + const float Rs = 0.5, Thetas = 3.14, Eta = 0.5, Zeta = 2.0; + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i >= start_serial && i < end_serial) { + float rij, rik, rjk, theta_jik; + float g_angular_lin = 0.; + for (int j = start_serial; j < end_serial; j = j + 1) { + if (j != i) { + rij = sqrtf(normfloat(crd, crd, i, j)); + for (int k = j + 1; k < end_serial; k = k + 1) { + if (k != i) { + rik = sqrtf(normfloat(crd, crd, i, k)); + rjk = sqrtf(normfloat(crd, crd, j, k)); + theta_jik = + acosf(fmaxf(fminf((rij * rij + rik * rik - rjk * rjk) / (2. * rij * rik), 0.999999), -0.999999)); + g_angular_lin = g_angular_lin + powf(1. + cosf(theta_jik - Thetas), Zeta) * + expf(-Eta * powf(0.5 * (rij + rik) - Rs, 2.)) * fc(rij) * fc(rik); + } else { + continue; + } + } + } else { + continue; + } + } + g_angular[i] = powf(2., 1. - Zeta) * g_angular_lin; + } +} + +void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f, + const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial, + float *g_angular, cudaStream_t stream) { + Reset_List<<(3. * atom_numbers) / 128), 128, 0, stream>>>(3 * atom_numbers, box_map_times, + 0); + Record_Box_Map_Times<<(3. * atom_numbers) / 128), 128, 0, stream>>>( + atom_numbers, crd_f, old_crd, box, box_map_times); + gen_nowarp_crd<<(3. * atom_numbers) / 128), 128, 0, stream>>>(atom_numbers, crd_f, box, + box_map_times, nowarp_crd); + G_Radial<<<1, number, 0, stream>>>(start_serial, end_serial, nowarp_crd, g_radial); + G_Angular<<<1, number, 0, stream>>>(start_serial, end_serial, nowarp_crd, g_angular); + return; +} + +void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f, + const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial, + float *g_angular, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh similarity index 56% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh index 4434a0d966a..f7edf8d7952 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh @@ -14,12 +14,13 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_ +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_ #include "runtime/device/gpu/cuda_common.h" -void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular, - cudaStream_t stream); +void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f, + const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial, + float *g_angular, cudaStream_t stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_ +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cu deleted file mode 100644 index 8f8ee0d043f..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cu +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh" -#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh" -#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" - -__device__ __host__ float fc(float Rij) { - const float PI = 3.141592654; - const float Rc = 1000.0; - return 0.5 * cosf(PI / Rc * Rij) + 0.5; -} - -__global__ void G_Radial(const int start_serial, const int end_serial, const float *crd, float *g_radial) { - const float Rs = 0.5, Eta = 0.5; - int i = blockDim.x * blockIdx.x + threadIdx.x; - if (i >= start_serial && i < end_serial) { - float rij; - float g_radial_lin = 0.; - for (int j = start_serial; j < end_serial; j = j + 1) { - if (j != i) { - // rij = sqrtf((crd[3*i+0] - crd[j]) * (crd[i] - crd[j])); - rij = sqrtf(normfloat(crd, crd, i, j)); - g_radial_lin = g_radial_lin + expf(-Eta * (rij - Rs) * (rij - Rs)) * fc(rij); - } else { - continue; - } - } - g_radial[i] = g_radial_lin; - } -} - -__global__ void G_Angular(const int start_serial, const int end_serial, const float *crd, float *g_angular) { - const float Rs = 0.5, Thetas = 3.14, Eta = 0.5, Zeta = 2.0; - int i = blockDim.x * blockIdx.x + threadIdx.x; - if (i >= start_serial && i < end_serial) { - float rij, rik, rjk, theta_jik; - float g_angular_lin = 0.; - for (int j = start_serial; j < end_serial; j = j + 1) { - if (j != i) { - rij = sqrtf(normfloat(crd, crd, i, j)); - for (int k = j + 1; k < end_serial; k = k + 1) { - if (k != i) { - rik = sqrtf(normfloat(crd, crd, i, k)); - rjk = sqrtf(normfloat(crd, crd, j, k)); - theta_jik = - acosf(fmaxf(fminf((rij * rij + rik * rik - rjk * rjk) / (2. * rij * rik), 0.999999), -0.999999)); - g_angular_lin = g_angular_lin + powf(1. + cosf(theta_jik - Thetas), Zeta) * - expf(-Eta * powf(0.5 * (rij + rik) - Rs, 2.)) * fc(rij) * fc(rik); - } else { - continue; - } - } - } else { - continue; - } - } - g_angular[i] = powf(2., 1. - Zeta) * g_angular_lin; - } -} - -void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular, - cudaStream_t stream) { - G_Radial<<<1, number, 0, stream>>>(start_serial, end_serial, crd_f, g_radial); - G_Angular<<<1, number, 0, stream>>>(start_serial, end_serial, crd_f, g_angular); - return; -} - -void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular, - cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/transfer_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/atomcrdtocv_kernel.cc similarity index 53% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/transfer_kernel.cc rename to mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/atomcrdtocv_kernel.cc index eca4df71799..360227ce4f6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/transfer_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/atomcrdtocv_kernel.cc @@ -14,13 +14,19 @@ * limitations under the License. */ -#include "backend/kernel_compiler/gpu/sponge/common/transfer_kernel.h" +#include "backend/kernel_compiler/gpu/sponge/common/atomcrdtocv_kernel.h" namespace mindspore { namespace kernel { -MS_REG_GPU_KERNEL_TWO( - TransferCrd, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - TransferGpuKernel, float, int) +MS_REG_GPU_KERNEL_TWO(TransferCrd, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32), + AtomCrdToCVGpuKernel, float, int) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/transfer_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/atomcrdtocv_kernel.h similarity index 65% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/transfer_kernel.h rename to mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/atomcrdtocv_kernel.h index 874de6a18e2..045036ca120 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/transfer_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/common/atomcrdtocv_kernel.h @@ -14,10 +14,10 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_ -#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh" #include #include @@ -31,19 +31,24 @@ namespace mindspore { namespace kernel { template -class TransferGpuKernel : public GpuKernel { +class AtomCrdToCVGpuKernel : public GpuKernel { public: - TransferGpuKernel() : ele_crd(1) {} - ~TransferGpuKernel() override = default; + AtomCrdToCVGpuKernel() : ele_crd(1) {} + ~AtomCrdToCVGpuKernel() override = default; bool Init(const CNodePtr &kernel_node) override { kernel_node_ = kernel_node; start_serial = static_cast(GetAttr(kernel_node, "start_serial")); end_serial = static_cast(GetAttr(kernel_node, "end_serial")); number = static_cast(GetAttr(kernel_node, "number")); + atom_numbers = static_cast(GetAttr(kernel_node, "atom_numbers")); auto shape_crd = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto shape_old_crd = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto shape_box = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); for (size_t i = 0; i < shape_crd.size(); i++) ele_crd *= shape_crd[i]; + for (size_t i = 0; i < shape_old_crd.size(); i++) ele_old_crd *= shape_old_crd[i]; + for (size_t i = 0; i < shape_box.size(); i++) ele_box *= shape_box[i]; InitSizeLists(); return true; @@ -53,27 +58,40 @@ class TransferGpuKernel : public GpuKernel { const std::vector &GetOutputSizeList() const override { return output_size_list_; } const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &, + bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { auto crd = GetDeviceAddress(inputs, 0); + auto old_crd = GetDeviceAddress(inputs, 1); + auto box = GetDeviceAddress(inputs, 2); auto g_radial = GetDeviceAddress(outputs, 0); auto g_angular = GetDeviceAddress(outputs, 1); - Transfer(start_serial, end_serial, number, crd, g_radial, g_angular, reinterpret_cast(stream_ptr)); + auto nowarp_crd = GetDeviceAddress(outputs, 2); + auto box_map_times = GetDeviceAddress(outputs, 3); + + AtomCrdToCV(atom_numbers, start_serial, end_serial, number, crd, old_crd, nowarp_crd, box_map_times, box, g_radial, + g_angular, reinterpret_cast(stream_ptr)); return true; } protected: void InitSizeLists() override { input_size_list_.push_back(ele_crd * sizeof(T)); + input_size_list_.push_back(ele_old_crd * sizeof(T)); + input_size_list_.push_back(ele_box * sizeof(T)); output_size_list_.push_back(number * sizeof(T)); output_size_list_.push_back(number * sizeof(T)); + + output_size_list_.push_back(3 * atom_numbers * sizeof(T)); + output_size_list_.push_back(3 * atom_numbers * sizeof(T1)); } private: size_t ele_crd = 1; + size_t ele_old_crd = 1; + size_t ele_box = 1; std::vector input_size_list_; std::vector output_size_list_; @@ -81,7 +99,8 @@ class TransferGpuKernel : public GpuKernel { int end_serial; int start_serial; int number; + int atom_numbers; }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_ diff --git a/mindspore/ops/operations/sponge_ops.py b/mindspore/ops/operations/sponge_ops.py index 92e6d25fc7a..0f5489f5dc8 100644 --- a/mindspore/ops/operations/sponge_ops.py +++ b/mindspore/ops/operations/sponge_ops.py @@ -2852,33 +2852,54 @@ class TransferCrd(PrimitiveWithInfer): Inputs: - **crd** (Tensor, float32) - [N, 3], the coordinate of each atom. + N is the number of atoms.. + - **old_crd** (Tensor, float32) - [N, 3], the last coordinate of each atom. N is the number of atoms. - + - **box** (Tensor, float32) - [3,], the length of 3 dimensions of the simulation box. Outputs: - - **output** (uint32) + - **radial** (Tensor, float32) - [number,], the array of radial transferred from coordinates. + - **angular** (Tensor, float32) - [number,], the array of angular transferred from coordinates. + - **nowarp_crd** (Tensor, float32) - [N, 3], the modified coordinate of each atom for + computing radial and angular. + - **box_map_times** (Tensor, int32) - [N, 3], the box map times for radial and angular. Supported Platforms: ``GPU`` """ @prim_attr_register - def __init__(self, start_serial, end_serial, number): + def __init__(self, start_serial, end_serial, number, atom_numbers): + validator.check_value_type('start_serial', start_serial, (int), self.name) + validator.check_value_type('end_serial', end_serial, (int), self.name) + validator.check_value_type('number', number, (int), self.name) + validator.check_value_type('atom_numbers', atom_numbers, (int), self.name) self.start_serial = start_serial self.end_serial = end_serial self.number = number + self.atom_numbers = atom_numbers self.add_prim_attr('start_serial', self.start_serial) self.add_prim_attr('end_serial', self.end_serial) self.add_prim_attr('number', self.number) + self.add_prim_attr('atom_numbers', self.atom_numbers) self.init_prim_io_names( - inputs=['crd'], - outputs=['radial', 'angular']) + inputs=['crd', 'old_crd', 'box'], + outputs=['radial', 'angular', 'nowarp_crd', 'box_map_times']) - def infer_shape(self, crd_shape): + def infer_shape(self, crd_shape, old_crd_shape, box_shape): + N = self.atom_numbers validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", self.name) - validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[0]", self.name) - return [self.number,], [self.number,] + validator.check_int(crd_shape[0], N, Rel.EQ, "crd_shape[0]", self.name) + validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", self.name) + validator.check_int(len(old_crd_shape), 2, Rel.EQ, "old_crd_dim", self.name) + validator.check_int(old_crd_shape[0], N, Rel.EQ, "old_crd_shape[0]", self.name) + validator.check_int(old_crd_shape[1], 3, Rel.EQ, "old_crd_shape[1]", self.name) + validator.check_int(len(box_shape), 1, Rel.EQ, "box_dim", self.name) + validator.check_int(box_shape[0], 3, Rel.EQ, "box_shape[0]", self.name) + return [self.number,], [self.number,], [N, 3], [N, 3] - def infer_dtype(self, crd_dtype): + def infer_dtype(self, crd_dtype, old_crd_dtype, box_dtype): validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name) - return mstype.float32, mstype.float32 + validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name) + validator.check_tensor_dtype_valid('box', box_dtype, [mstype.float32], self.name) + return mstype.float32, mstype.float32, mstype.float32, mstype.int32 diff --git a/model_zoo/research/hpc/sponge/main.py b/model_zoo/research/hpc/sponge/main.py index 2c1f9ca2e3d..9f37635f6c8 100644 --- a/model_zoo/research/hpc/sponge/main.py +++ b/model_zoo/research/hpc/sponge/main.py @@ -17,73 +17,57 @@ import argparse import time from src.simulation import Simulation - +from src.mdnn import Mdnn, TransCrdToCV import mindspore.context as context from mindspore import Tensor +from mindspore import load_checkpoint -parser = argparse.ArgumentParser(description='Sponge Controller') -parser.add_argument('--i', type=str, default=None, help='input file') -parser.add_argument('--amber_parm', type=str, default=None, help='paramter file in AMBER type') -parser.add_argument('--c', type=str, default=None, help='initial coordinates file') +parser = argparse.ArgumentParser(description='SPONGE Controller') +parser.add_argument('--i', type=str, default=None, help='Input file') +parser.add_argument('--amber_parm', type=str, default=None, help='Paramter file in AMBER type') +parser.add_argument('--c', type=str, default=None, help='Initial coordinates file') parser.add_argument('--r', type=str, default="restrt", help='') parser.add_argument('--x', type=str, default="mdcrd", help='') -parser.add_argument('--o', type=str, default="mdout", help="") +parser.add_argument('--o', type=str, default="mdout", help='Output file') parser.add_argument('--box', type=str, default="mdbox", help='') -parser.add_argument('--device_id', type=int, default=0, help='') +parser.add_argument('--device_id', type=int, default=0, help='GPU device id') +parser.add_argument('--u', type=bool, default=False, help='If use mdnn to update the atom charge') +parser.add_argument('--checkpoint', type=str, default="", help='Checkpoint file') args_opt = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False) if __name__ == "__main__": simulation = Simulation(args_opt) + if args_opt.u and args_opt.checkpoint: + net = Mdnn() + load_checkpoint(args_opt.checkpoint, net=net) + transcrd = TransCrdToCV(simulation) + start = time.time() compiler_time = 0 save_path = args_opt.o - file = open(save_path, 'w') + simulation.Main_Initial() for steps in range(simulation.md_info.step_limit): print_step = steps % simulation.ntwx if steps == simulation.md_info.step_limit - 1: print_step = 0 temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \ nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene, _ = simulation(Tensor(steps), Tensor(print_step)) + if steps == 0: compiler_time = time.time() if steps % simulation.ntwx == 0 or steps == simulation.md_info.step_limit - 1: - if steps == 0: - print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ " - "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_") - file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ " - "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n") + simulation.Main_Print(steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, + sigma_of_dihedral_ene, nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene) - temperature = temperature.asnumpy() - total_potential_energy = total_potential_energy.asnumpy() - print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)), - end=" ") - if simulation.bond.bond_numbers > 0: - sigma_of_bond_ene = sigma_of_bond_ene.asnumpy() - print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ") - if simulation.angle.angle_numbers > 0: - sigma_of_angle_ene = sigma_of_angle_ene.asnumpy() - print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ") - if simulation.dihedral.dihedral_numbers > 0: - sigma_of_dihedral_ene = sigma_of_dihedral_ene.asnumpy() - print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ") - if simulation.nb14.nb14_numbers > 0: - nb14_lj_energy_sum = nb14_lj_energy_sum.asnumpy() - nb14_cf_energy_sum = nb14_cf_energy_sum.asnumpy() - print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ") - LJ_energy_sum = LJ_energy_sum.asnumpy() - ee_ene = ee_ene.asnumpy() - print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ") - print("{:>12.3f}".format(float(ee_ene))) - if file is not None: - file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}" - " {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy), - float(sigma_of_bond_ene), float(sigma_of_angle_ene), - float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum), - float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene))) + if args_opt.u and args_opt.checkpoint and steps % (4 * simulation.ntwx) == 0: + print("Update charge!") + inputs = transcrd(Tensor(simulation.crd), Tensor(simulation.last_crd)) + t_charge = net(inputs) + simulation.charge = transcrd.updatecharge(t_charge) end = time.time() - file.close() print("Main time(s):", end - start) print("Main time(s) without compiler:", end - compiler_time) + simulation.Main_Destroy() diff --git a/model_zoo/research/hpc/sponge/src/mdnn.py b/model_zoo/research/hpc/sponge/src/mdnn.py new file mode 100644 index 00000000000..0261d091e04 --- /dev/null +++ b/model_zoo/research/hpc/sponge/src/mdnn.py @@ -0,0 +1,69 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""mdnn class""" + +import numpy as np +from mindspore import nn, Tensor +from mindspore.ops import operations as P +from mindspore.common.parameter import Parameter +import mindspore.common.dtype as mstype + + +class Mdnn(nn.Cell): + """Mdnn""" + + def __init__(self, dim=258, dr=0.5): + super(Mdnn, self).__init__() + self.dim = dim + self.dr = dr # dropout_ratio + self.fc1 = nn.Dense(dim, 512) + self.fc2 = nn.Dense(512, 512) + self.fc3 = nn.Dense(512, 512) + self.fc4 = nn.Dense(512, 129) + self.tanh = nn.Tanh() + + def construct(self, x): + """construct""" + x = self.tanh(self.fc1(x)) + x = self.tanh(self.fc2(x)) + x = self.tanh(self.fc3(x)) + x = self.fc4(x) + return x + + +class TransCrdToCV(nn.Cell): + """TransCrdToCV""" + + def __init__(self, simulation): + super(TransCrdToCV, self).__init__() + self.atom_numbers = simulation.atom_numbers + self.transfercrd = P.TransferCrd(0, 129, 129, self.atom_numbers) + self.box = Tensor(simulation.box_length) + self.radial = Parameter(Tensor(np.zeros([129,]), mstype.float32)) + self.angular = Parameter(Tensor(np.zeros([129,]), mstype.float32)) + self.output = Parameter(Tensor(np.zeros([1, 258]), mstype.float32)) + self.charge = simulation.charge + + def updatecharge(self, t_charge): + """update charge in simulation""" + self.charge[:129] = t_charge[0] * 18.2223 + return self.charge + + def construct(self, crd, last_crd): + """construct""" + self.radial, self.angular, _, _ = self.transfercrd(crd, last_crd, self.box) + self.output = P.Concat()((self.radial, self.angular)) + self.output = P.ExpandDims()(self.output, 0) + return self.output diff --git a/model_zoo/research/hpc/sponge/src/simulation.py b/model_zoo/research/hpc/sponge/src/simulation.py index 507eb67ee06..3e920045de5 100644 --- a/model_zoo/research/hpc/sponge/src/simulation.py +++ b/model_zoo/research/hpc/sponge/src/simulation.py @@ -34,6 +34,7 @@ from src.particle_mesh_ewald import Particle_Mesh_Ewald class controller: '''controller''' + def __init__(self, args_opt): self.input_file = args_opt.i self.initial_coordinates_file = args_opt.c @@ -67,6 +68,7 @@ class controller: class Simulation(nn.Cell): '''simulation''' + def __init__(self, args_opt): super(Simulation, self).__init__() self.control = controller(args_opt) @@ -119,6 +121,7 @@ class Simulation(nn.Cell): self.exp_gamma = self.liujian_info.exp_gamma self.init_Tensor() self.op_define() + self.update = False def init_Tensor(self): '''init tensor''' @@ -129,9 +132,12 @@ class Simulation(nn.Cell): self.uint_dr_to_dr_cof = Parameter( Tensor(np.asarray(self.md_info.uint_dr_to_dr_cof, np.float32), mstype.float32), requires_grad=False) self.box_length = Tensor(self.md_info.box_length, mstype.float32) - self.charge = Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32) + self.charge = Parameter(Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32), + requires_grad=False) self.old_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32), requires_grad=False) + self.last_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32), + requires_grad=False) self.uint_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.uint32), mstype.uint32), requires_grad=False) self.mass_inverse = Tensor(self.md_info.h_mass_inverse, mstype.float32) @@ -341,8 +347,65 @@ class Simulation(nn.Cell): acc = F.depend(self.acc, crd) return vel, crd, acc + def Main_Print(self, *args): + """compute the temperature""" + steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \ + nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene = list(args) + if steps == 0: + print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ " + "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_") + + temperature = temperature.asnumpy() + total_potential_energy = total_potential_energy.asnumpy() + print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)), + end=" ") + if self.bond.bond_numbers > 0: + sigma_of_bond_ene = sigma_of_bond_ene.asnumpy() + print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ") + if self.angle.angle_numbers > 0: + sigma_of_angle_ene = sigma_of_angle_ene.asnumpy() + print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ") + if self.dihedral.dihedral_numbers > 0: + sigma_of_dihedral_ene = sigma_of_dihedral_ene.asnumpy() + print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ") + if self.nb14.nb14_numbers > 0: + nb14_lj_energy_sum = nb14_lj_energy_sum.asnumpy() + nb14_cf_energy_sum = nb14_cf_energy_sum.asnumpy() + print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ") + LJ_energy_sum = LJ_energy_sum.asnumpy() + ee_ene = ee_ene.asnumpy() + print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ") + print("{:>12.3f}".format(float(ee_ene))) + if self.file is not None: + self.file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}" + " {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy), + float(sigma_of_bond_ene), float(sigma_of_angle_ene), + float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum), + float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene))) + if self.datfile is not None: + self.datfile.write(self.crd.asnumpy()) + + def Main_Initial(self): + """main initial""" + if self.control.mdout: + self.file = open(self.control.mdout, 'w') + self.file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ " + "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n") + if self.control.mdcrd: + self.datfile = open(self.control.mdcrd, 'wb') + + def Main_Destroy(self): + """main destroy""" + if self.file is not None: + self.file.close() + print("Save .out file successfully!") + if self.datfile is not None: + self.datfile.close() + print("Save .dat file successfully!") + def construct(self, step, print_step): '''construct''' + self.last_crd = self.crd if step == 0: res = self.neighbor_list_update_init(self.atom_numbers_in_grid_bucket, self.bucket, self.crd, self.box_length, self.grid_N, self.grid_length_inverse, diff --git a/model_zoo/research/hpc/sponge/train_mdnn.py b/model_zoo/research/hpc/sponge/train_mdnn.py new file mode 100644 index 00000000000..5b75d1a5b86 --- /dev/null +++ b/model_zoo/research/hpc/sponge/train_mdnn.py @@ -0,0 +1,111 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train""" +import argparse +import numpy as np +from src.mdnn import Mdnn +from mindspore import nn, Model, context +from mindspore import dataset as ds +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.train.callback import Callback +import mindspore.common.initializer as weight_init + +parser = argparse.ArgumentParser(description='Mdnn Controller') +parser.add_argument('--i', type=str, default=None, help='Input radial and angular dat file') +parser.add_argument('--charge', type=str, default=None, help='Input charge dat file') +parser.add_argument('--device_id', type=int, default=0, help='GPU device id') +args_opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False) + + +class StepLossAccInfo(Callback): + """custom callback function""" + + def __init__(self, models, eval_dataset, steploss): + """init model""" + self.model = models + self.eval_dataset = eval_dataset + self.steps_loss = steploss + + def step_end(self, run_context): + """step end""" + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + cur_step = (cur_epoch - 1) * 1875 + cb_params.cur_step_num + self.steps_loss["loss_value"].append(str(cb_params.net_outputs)) + self.steps_loss["step"].append(str(cur_step)) + + +def get_data(inputdata, outputdata): + """get data function""" + for _, data in enumerate(zip(inputdata, outputdata)): + yield data + + +def create_dataset(inputdata, outputdata, batchsize=32, repeat_size=1): + """create dataset function""" + + input_data = ds.GeneratorDataset(list(get_data(inputdata, outputdata)), column_names=['data', 'label']) + input_data = input_data.batch(batchsize) + input_data = input_data.repeat(repeat_size) + return input_data + + +def init_weight(nnet): + for _, cell in nnet.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(), + cell.weight.shape, + cell.weight.dtype)) + if isinstance(cell, nn.Dense): + cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(), + cell.weight.shape, + cell.weight.dtype)) + + +if __name__ == '__main__': + # read input files + inputs = args_opt.i + outputs = args_opt.charge + radial_angular = np.fromfile(inputs, dtype=np.float32) + radial_angular = radial_angular.reshape((-1, 258)).astype(np.float32) + charge = np.fromfile(outputs, dtype=np.float32) + charge = charge.reshape((-1, 129)).astype(np.float32) + # define the model + net = Mdnn() + lr = 0.0001 + decay_rate = 0.8 + epoch_size = 1000 + batch_size = 500 + total_step = epoch_size * batch_size + step_per_epoch = 100 + decay_epoch = epoch_size + lr_rate = nn.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch) + net_loss = nn.loss.MSELoss(reduction='mean') + net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_rate) + model = Model(net, net_loss, net_opt) + ds_train = create_dataset(radial_angular, charge, batchsize=batch_size) + model_params = net.trainable_params() + net.set_train() + init_weight(net) + # config files + path = './params/' + config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10) + ckpoint_cb = ModelCheckpoint(prefix="mdnn_best", directory=path, config=config_ck) + steps_loss = {"step": [], "loss_value": []} + step_loss_acc_info = StepLossAccInfo(model, ds_train, steps_loss) + # train the model + model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(100)])