forked from mindspore-Ecosystem/mindspore
04084 sponge operators
parent a3fc997c4e
commit 5694cf4753
@@ -0,0 +1,83 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"

// Smooth cosine cutoff: 1 at Rij = 0, decaying to 0 at Rij = Rc.
__device__ __host__ float fc(float Rij) {
  const float PI = 3.141592654;
  const float Rc = 1000.0;
  return 0.5 * cosf(PI / Rc * Rij) + 0.5;
}

// Radial symmetry function: for each atom i in [start_serial, end_serial),
// accumulate a Gaussian of the pair distance rij, weighted by the cutoff fc,
// over all other atoms j in the same range.
__global__ void G_Radial(const int start_serial, const int end_serial, const float *crd, float *g_radial) {
  const float Rs = 0.5, Eta = 0.5;
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  if (i >= start_serial && i < end_serial) {
    float rij;
    float g_radial_lin = 0.;
    for (int j = start_serial; j < end_serial; j = j + 1) {
      if (j != i) {
        rij = sqrtf(normfloat(crd, crd, i, j));  // normfloat returns the squared distance
        g_radial_lin = g_radial_lin + expf(-Eta * (rij - Rs) * (rij - Rs)) * fc(rij);
      } else {
        continue;
      }
    }
    g_radial[i] = g_radial_lin;
  }
}

// Angular symmetry function: for each atom i, accumulate over atom pairs (j, k)
// a term built from the angle theta_jik (law of cosines, clamped before acosf)
// and the mean pair distance, weighted by the cutoffs fc(rij) * fc(rik).
__global__ void G_Angular(const int start_serial, const int end_serial, const float *crd, float *g_angular) {
  const float Rs = 0.5, Thetas = 3.14, Eta = 0.5, Zeta = 2.0;
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  if (i >= start_serial && i < end_serial) {
    float rij, rik, rjk, theta_jik;
    float g_angular_lin = 0.;
    for (int j = start_serial; j < end_serial; j = j + 1) {
      if (j != i) {
        rij = sqrtf(normfloat(crd, crd, i, j));
        for (int k = j + 1; k < end_serial; k = k + 1) {
          if (k != i) {
            rik = sqrtf(normfloat(crd, crd, i, k));
            rjk = sqrtf(normfloat(crd, crd, j, k));
            theta_jik =
              acosf(fmaxf(fminf((rij * rij + rik * rik - rjk * rjk) / (2. * rij * rik), 0.999999), -0.999999));
            g_angular_lin = g_angular_lin + powf(1. + cosf(theta_jik - Thetas), Zeta) *
                                              expf(-Eta * powf(0.5 * (rij + rik) - Rs, 2.)) * fc(rij) * fc(rik);
          } else {
            continue;
          }
        }
      } else {
        continue;
      }
    }
    g_angular[i] = powf(2., 1. - Zeta) * g_angular_lin;
  }
}

void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular,
              cudaStream_t stream) {
  G_Radial<<<1, number, 0, stream>>>(start_serial, end_serial, crd_f, g_radial);
  G_Angular<<<1, number, 0, stream>>>(start_serial, end_serial, crd_f, g_angular);
  return;
}

void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular,
              cudaStream_t stream);
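The two kernels above compute per-atom radial and angular symmetry functions over the atoms in [start_serial, end_serial). Below is a minimal NumPy sketch of the same math, useful for spot-checking the CUDA output on small inputs; the constants mirror the values hard-coded in the kernels, and the function names (g_radial_ref, g_angular_ref) are ours, not part of the commit.

import numpy as np

# Constants mirror the values hard-coded in the kernels above.
PI, RC = 3.141592654, 1000.0
RS, ETA = 0.5, 0.5
THETAS, ZETA = 3.14, 2.0

def fc(rij):
    # Smooth cosine cutoff: 1 at rij = 0, 0 at rij = RC.
    return 0.5 * np.cos(PI / RC * rij) + 0.5

def g_radial_ref(crd, start_serial, end_serial, number):
    # crd: [N, 3] float array; returns a [number] array, filled at indices
    # [start_serial, end_serial) exactly as G_Radial fills its output buffer.
    out = np.zeros(number, dtype=np.float32)
    for i in range(start_serial, end_serial):
        acc = 0.0
        for j in range(start_serial, end_serial):
            if j == i:
                continue
            rij = np.linalg.norm(crd[i] - crd[j])
            acc += np.exp(-ETA * (rij - RS) ** 2) * fc(rij)
        out[i] = acc
    return out

def g_angular_ref(crd, start_serial, end_serial, number):
    out = np.zeros(number, dtype=np.float32)
    for i in range(start_serial, end_serial):
        acc = 0.0
        for j in range(start_serial, end_serial):
            if j == i:
                continue
            rij = np.linalg.norm(crd[i] - crd[j])
            for k in range(j + 1, end_serial):
                if k == i:
                    continue
                rik = np.linalg.norm(crd[i] - crd[k])
                rjk = np.linalg.norm(crd[j] - crd[k])
                # Law of cosines, clamped like the kernel before arccos.
                cos_jik = np.clip((rij * rij + rik * rik - rjk * rjk) / (2.0 * rij * rik),
                                  -0.999999, 0.999999)
                theta_jik = np.arccos(cos_jik)
                acc += ((1.0 + np.cos(theta_jik - THETAS)) ** ZETA
                        * np.exp(-ETA * (0.5 * (rij + rik) - RS) ** 2)
                        * fc(rij) * fc(rik))
        out[i] = 2.0 ** (1.0 - ZETA) * acc
    return out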
@@ -0,0 +1,25 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_

#include "runtime/device/gpu/cuda_common.h"

void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular,
              cudaStream_t stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_
@@ -128,6 +128,14 @@ __device__ __host__ static inline VECTOR operator^(const VECTOR &veca, const VEC
  return vec;
}

// Returns the squared Euclidean distance between atom i of x and atom j of y,
// where x and y are flat [N * 3] coordinate arrays (callers apply sqrtf).
__device__ __host__ static inline float normfloat(const float *x, const float *y, int i, int j) {
  float s = 0;
  s += (x[3 * i + 0] - y[3 * j + 0]) * (x[3 * i + 0] - y[3 * j + 0]);
  s += (x[3 * i + 1] - y[3 * j + 1]) * (x[3 * i + 1] - y[3 * j + 1]);
  s += (x[3 * i + 2] - y[3 * j + 2]) * (x[3 * i + 2] - y[3 * j + 2]);
  return s;
}

__global__ static void construct_neighbor_list_kernel(int atom_numbers, int max_neighbor_numbers, int *nl_atom_numbers,
                                                      int *nl_atom_serial, NEIGHBOR_LIST *nl) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < atom_numbers; i += gridDim.x * blockDim.x) {
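A one-to-one Python rendering of normfloat (the name normfloat_ref is ours), handy for spot checks; note it returns the squared distance and callers take the square root:

import numpy as np

def normfloat_ref(x, y, i, j):
    # x, y: flat [N * 3] coordinate arrays; returns the squared distance
    # between atom i of x and atom j of y (no sqrt, matching the CUDA helper).
    xi = np.asarray(x).reshape(-1, 3)[i]
    yj = np.asarray(y).reshape(-1, 3)[j]
    return float(np.dot(xi - yj, xi - yj))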
@@ -56,6 +56,7 @@ void Dihedral14CFEnergy(const int dihedral_14_numbers, const int atom_numbers, c
                        const int *b_14, const float *cf_scale_factor, float *ene, cudaStream_t stream) {
  size_t thread_per_block = 32;
  size_t block_per_grid = ceilf(static_cast<float>(dihedral_14_numbers) / 32);

  UNSIGNED_INT_VECTOR *uint_crd =
    const_cast<UNSIGNED_INT_VECTOR *>(reinterpret_cast<const UNSIGNED_INT_VECTOR *>(uint_crd_f));

@@ -65,6 +66,7 @@ void Dihedral14CFEnergy(const int dihedral_14_numbers, const int atom_numbers, c
    atom_numbers, uint_crd, uint_crd_with_LJ, LJtype, charge);

  VECTOR *boxlength = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(boxlength_f));

  Dihedral14CFEnergyKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
    dihedral_14_numbers, uint_crd_with_LJ, boxlength, a_14, b_14, cf_scale_factor, ene);
@@ -121,7 +121,7 @@ void Dihedral14LJCFForceWithAtomEnergy(const int dihedral_14_numbers, const int
    atom_numbers, uint_crd, uint_crd_with_LJ, LJtype, charge);

  Reset_List<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(3 * atom_numbers, frc_f, 0.);
-  Reset_List<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(atom_numbers, atom_energy, 0.);
+  Reset_List<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(atom_numbers, atom_energy, 0.);
  VECTOR *boxlength = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(boxlength_f));
  VECTOR *frc = const_cast<VECTOR *>(reinterpret_cast<const VECTOR *>(frc_f));
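The one-line change above sizes the reset grid to the element count it actually clears: frc_f holds 3 * atom_numbers floats, while atom_energy holds only atom_numbers. The grid computation is plain ceil-division; a Python sketch of the rule, for reference:

def block_per_grid(element_numbers, thread_per_block=128):
    # Smallest grid size such that grid * thread_per_block >= element_numbers.
    return (element_numbers + thread_per_block - 1) // thread_per_block

assert block_per_grid(300) == 3  # e.g. atom_numbers = 100, frc_f has 300 floats
assert block_per_grid(100) == 1  # atom_energy needs only one block of 128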
@@ -0,0 +1,27 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/sponge/common/transfer_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  TransferCrd,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  TransferGpuKernel, float, int)

}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,87 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_

#include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh"

#include <cuda_runtime_api.h>
#include <map>
#include <string>
#include <vector>

#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/cuda_common.h"

namespace mindspore {
namespace kernel {
template <typename T, typename T1>
class TransferGpuKernel : public GpuKernel {
 public:
  TransferGpuKernel() : ele_crd(1) {}
  ~TransferGpuKernel() override = default;

  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    // Read the operator attributes registered on the primitive.
    start_serial = static_cast<int>(GetAttr<int64_t>(kernel_node, "start_serial"));
    end_serial = static_cast<int>(GetAttr<int64_t>(kernel_node, "end_serial"));
    number = static_cast<int>(GetAttr<int64_t>(kernel_node, "number"));
    auto shape_crd = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);

    for (size_t i = 0; i < shape_crd.size(); i++) ele_crd *= shape_crd[i];

    InitSizeLists();
    return true;
  }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto crd = GetDeviceAddress<const T>(inputs, 0);

    auto g_radial = GetDeviceAddress<T>(outputs, 0);
    auto g_angular = GetDeviceAddress<T>(outputs, 1);

    Transfer(start_serial, end_serial, number, crd, g_radial, g_angular, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(ele_crd * sizeof(T));

    output_size_list_.push_back(number * sizeof(T));
    output_size_list_.push_back(number * sizeof(T));
  }

 private:
  size_t ele_crd = 1;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
  int end_serial;
  int start_serial;
  int number;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_
@@ -106,7 +106,7 @@ from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAto
                         Dihedral14LJForceWithDirectCF, Dihedral14LJEnergy, Dihedral14LJCFForceWithAtomEnergy,
                         Dihedral14LJAtomEnergy, Dihedral14CFEnergy, Dihedral14CFAtomEnergy, MDIterationLeapFrog,
                         GetCenterOfGeometry, MDTemperature, NeighborListUpdate, MDIterationLeapFrogLiujian,
-                        CrdToUintCrd, MDIterationSetupRandState)
+                        CrdToUintCrd, MDIterationSetupRandState, TransferCrd)


__all__ = [
@@ -469,6 +469,8 @@ __all__ = [
    "MDIterationLeapFrogLiujian",
    "CrdToUintCrd",
    "MDIterationSetupRandState",
+    "TransferCrd",
+
]

__all__.sort()
@@ -1141,6 +1141,7 @@ class Dihedral14LJForce(PrimitiveWithInfer):
                    lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
        cls_name = self.name
        N = self.atom_numbers
+        M = self.dihedral_14_numbers
        Q = LJ_type_A_shape[0]
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)

@@ -1157,6 +1158,9 @@ class Dihedral14LJForce(PrimitiveWithInfer):
        validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
        validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
+        validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name)
+        validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name)
+        validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor_shape", cls_name)
        return uint_crd_f_shape

    def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@@ -1229,6 +1233,7 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
                    lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
        cls_name = self.name
        N = self.atom_numbers
+        M = self.dihedral_14_numbers
        Q = LJ_type_A_shape[0]
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)

@@ -1245,6 +1250,9 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
        validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
        validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
+        validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name)
+        validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name)
+        validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor_shape", cls_name)
        return [self.dihedral_14_numbers,]

    def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@@ -1336,9 +1344,13 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
        validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
        validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
        validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
-        validator.check_int(charge_shape[0], M, Rel.EQ, "charge_shape", cls_name)
+        validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
        validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
+        validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name)
+        validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name)
+        validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor_shape", cls_name)
+        validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor_shape", cls_name)
        return [self.atom_numbers, 3]

    def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@@ -1414,6 +1426,7 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
                    lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
        cls_name = self.name
        N = self.atom_numbers
+        M = self.dihedral_14_numbers
        Q = LJ_type_A_shape[0]
        validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
        validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)

@@ -1431,6 +1444,10 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
        validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
        validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
+        validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name)
+        validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name)
+        validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor_shape", cls_name)
+        validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor_shape", cls_name)
        return uint_crd_f_shape, charge_shape

    def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@@ -1515,6 +1532,10 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
        validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
        validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
+        M = self.dihedral_14_numbers
+        validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name)
+        validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name)
+        validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor_shape", cls_name)
        return LJtype_shape

    def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@@ -1596,6 +1617,10 @@ class Dihedral14CFEnergy(PrimitiveWithInfer):
        validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
+        M = self.dihedral_14_numbers
+        validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name)
+        validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name)
+        validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor_shape", cls_name)
        return [self.dihedral_14_numbers,]

    def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@@ -1671,6 +1696,10 @@ class Dihedral14CFAtomEnergy(PrimitiveWithInfer):
        validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
        validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
        validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
+        M = self.dihedral_14_numbers
+        validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name)
+        validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name)
+        validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor_shape", cls_name)
        return LJtype_shape

    def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@@ -2674,14 +2703,15 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer):
    scheme for efficient configurational sampling for classical/quantum canonical
    ensembles via molecular dynamics. DOI: 10.1063/1.4991621.

-    Inputs:
-        - **atom_numbers** (int32) - the number of atoms N.
-        - **dt** (float32) - time step for finite difference.
-        - **half_dt** (float32) - half of time step for finite difference.
-        - **exp_gamma** (float32) - parameter in Liu's dynamic, equals
+    Args:
+        atom_numbers(int32): the number of atoms N.
+        dt(float32): time step for finite difference.
+        half_dt(float32): half of the time step for finite difference.
+        exp_gamma(float32): parameter in Liu's dynamics, equal to
+            exp(-gamma_ln * dt), where gamma_ln is the friction factor in
+            Langevin dynamics.

    Inputs:
        - **inverse_mass** (Tensor, float32) - [N,], the inverse value of
          mass of each atom.
        - **sqrt_mass_inverse** (Tensor, float32) - [N,], the inverse square root value

@@ -2699,7 +2729,6 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer):

    Supported Platforms:
        ``GPU``
-    Examples:
    """

    @prim_attr_register
@@ -2735,14 +2764,14 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer):
        validator.check_tensor_dtype_valid('rand_frc', rand_frc, [mstype.float32], self.name)
        return mstype.float32


class CrdToUintCrd(PrimitiveWithInfer):
    """
    Convert FP32 coordinates to Uint32 coordinates.

-    Inputs:
-        - **atom_numbers** (int32) - the number of atoms N.
+    Args:
+        atom_numbers(int32): the number of atoms N.

    Inputs:
        - **crd_to_uint_crd_cof** (Tensor, float32) - [3,], the coefficient
          that converts float coordinates to unsigned int coordinates.
        - **crd** (Tensor, float32) - [N, 3], the coordinate of each atom.


@@ -2751,7 +2780,6 @@ class CrdToUintCrd(PrimitiveWithInfer):

    Supported Platforms:
        ``GPU``
-    Examples:
    """

    @prim_attr_register

@@ -2763,7 +2791,7 @@ class CrdToUintCrd(PrimitiveWithInfer):
                                 outputs=['output'])

    def infer_shape(self, crd_to_uint_crd_cof, crd):
-        validator.check_int(crd_to_uint_crd_cof[0], 3, Rel.EQ, "crd_to_uint_crd_cof", self.name)
+        validator.check_int(crd_to_uint_crd_cof[0], 3, Rel.EQ, "crd_to_uint_crd_cof_shape", self.name)
        validator.check_int(crd[0], self.atom_numbers, Rel.EQ, "crd[0]", self.name)
        validator.check_int(crd[1], 3, Rel.EQ, "crd[1]", self.name)
        return crd
@@ -2773,21 +2801,19 @@ class CrdToUintCrd(PrimitiveWithInfer):
        validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name)
        return mstype.uint32


class MDIterationSetupRandState(PrimitiveWithInfer):
    """
    Set up the random state used by the MD iteration.

-    Inputs:
-        - **atom_numbers** (int32) - the number of atoms N.
-        - **seed** (int32) - random seed.
+    Args:
+        atom_numbers(int32): the number of atoms N.
+        seed(int32): random seed.

    Outputs:
        - **output** (uint32) - random state.

    Supported Platforms:
        ``GPU``
-    Examples:
    """

    @prim_attr_register
@@ -2808,3 +2834,45 @@ class MDIterationSetupRandState(PrimitiveWithInfer):

    def infer_dtype(self):
        return mstype.float32


class TransferCrd(PrimitiveWithInfer):
    """
    Transfer the coordinates to angular and radial.

    Args:
        start_serial(int32): the index start position.
        end_serial(int32): the index end position.
        number(int32): the length of angular and radial.

    Inputs:
        - **crd** (Tensor, float32) - [N, 3], the coordinate of each atom.
          N is the number of atoms.

    Outputs:
        - **radial** (Tensor, float32) - [number,], the transferred radial values.
        - **angular** (Tensor, float32) - [number,], the transferred angular values.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, start_serial, end_serial, number):
        self.start_serial = start_serial
        self.end_serial = end_serial
        self.number = number
        self.add_prim_attr('start_serial', self.start_serial)
        self.add_prim_attr('end_serial', self.end_serial)
        self.add_prim_attr('number', self.number)
        self.init_prim_io_names(
            inputs=['crd'],
            outputs=['radial', 'angular'])

    def infer_shape(self, crd_shape):
        validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", self.name)
        validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", self.name)
        return [self.number,], [self.number,]

    def infer_dtype(self, crd_dtype):
        validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
        return mstype.float32, mstype.float32
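A minimal usage sketch of the new TransferCrd primitive, assuming a GPU build of MindSpore; the shapes and attribute values are illustrative only:

import numpy as np
import mindspore as ms
from mindspore import context
from mindspore.ops.operations import TransferCrd

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

n_atoms = 8
transfer = TransferCrd(start_serial=0, end_serial=n_atoms, number=n_atoms)
crd = ms.Tensor(np.random.rand(n_atoms, 3).astype(np.float32))
radial, angular = transfer(crd)  # two [number,] float32 tensors
print(radial.shape, angular.shape)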