forked from mindspore-Ecosystem/mindspore
!14144 update sponge
From: @zhangxinfeng3 Reviewed-by: @wang_zi_dong,@ljl0711 Signed-off-by: @ljl0711
This commit is contained in:
commit
0a0dc05d51
|
@ -35,6 +35,7 @@ class PMEEnergyGpuKernel : public GpuKernel {
|
|||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
|
||||
excluded_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "excluded_numbers"));
|
||||
beta = static_cast<float>(GetAttr<float_t>(kernel_node, "beta"));
|
||||
fftx = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftx"));
|
||||
ffty = static_cast<int>(GetAttr<int64_t>(kernel_node, "ffty"));
|
||||
|
@ -90,7 +91,7 @@ class PMEEnergyGpuKernel : public GpuKernel {
|
|||
input_size_list_.push_back(atom_numbers * sizeof(VECTOR));
|
||||
input_size_list_.push_back(atom_numbers * sizeof(T1));
|
||||
input_size_list_.push_back(max_nl_numbers * sizeof(T1));
|
||||
input_size_list_.push_back(atom_numbers * sizeof(VECTOR));
|
||||
input_size_list_.push_back(sizeof(VECTOR));
|
||||
|
||||
input_size_list_.push_back(atom_numbers * sizeof(T1));
|
||||
input_size_list_.push_back(excluded_numbers * sizeof(T1));
|
||||
|
@ -118,7 +119,7 @@ class PMEEnergyGpuKernel : public GpuKernel {
|
|||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
int atom_numbers;
|
||||
int excluded_numbers = 2719;
|
||||
int excluded_numbers;
|
||||
int max_nl_numbers = 800;
|
||||
int fftx;
|
||||
int ffty;
|
||||
|
|
|
@ -35,6 +35,7 @@ class PMEExcludedForceGpuKernel : public GpuKernel {
|
|||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
|
||||
excluded_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "excluded_numbers"));
|
||||
beta = static_cast<float>(GetAttr<float_t>(kernel_node, "beta"));
|
||||
InitSizeLists();
|
||||
return true;
|
||||
|
@ -62,7 +63,7 @@ class PMEExcludedForceGpuKernel : public GpuKernel {
|
|||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(atom_numbers * sizeof(UNSIGNED_INT_VECTOR));
|
||||
input_size_list_.push_back(atom_numbers * sizeof(VECTOR));
|
||||
input_size_list_.push_back(sizeof(VECTOR));
|
||||
input_size_list_.push_back(atom_numbers * sizeof(T));
|
||||
input_size_list_.push_back(atom_numbers * sizeof(T1));
|
||||
input_size_list_.push_back(excluded_numbers * sizeof(T1));
|
||||
|
@ -77,7 +78,7 @@ class PMEExcludedForceGpuKernel : public GpuKernel {
|
|||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
int atom_numbers;
|
||||
int excluded_numbers = 2719;
|
||||
int excluded_numbers;
|
||||
float beta;
|
||||
struct VECTOR {
|
||||
float x;
|
||||
|
|
|
@ -69,13 +69,19 @@ class BondForce(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.bond_numbers
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name)
|
||||
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
|
||||
validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
|
||||
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
|
||||
return uint_crd_f_shape
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
|
||||
|
@ -136,13 +142,19 @@ class BondEnergy(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.bond_numbers
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name)
|
||||
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
|
||||
validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
|
||||
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
|
||||
|
||||
return bond_k_shape
|
||||
|
||||
|
@ -198,13 +210,19 @@ class BondAtomEnergy(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.bond_numbers
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name)
|
||||
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
|
||||
validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
|
||||
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
|
||||
return [N,]
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
|
||||
|
@ -259,13 +277,19 @@ class BondForceWithAtomEnergy(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.bond_numbers
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name)
|
||||
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
|
||||
validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
|
||||
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
|
||||
|
||||
return uint_crd_f_shape, [N,]
|
||||
|
||||
|
@ -333,13 +357,19 @@ class BondForceWithAtomVirial(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.bond_numbers
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name)
|
||||
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
|
||||
validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
|
||||
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
|
||||
|
||||
return uint_crd_f_shape, [N,]
|
||||
|
||||
|
@ -436,17 +466,29 @@ class DihedralForce(PrimitiveWithInfer):
|
|||
ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
|
||||
cls_name = self.name
|
||||
M = self.dihedral_numbers
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
||||
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name)
|
||||
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name)
|
||||
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name)
|
||||
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name)
|
||||
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name)
|
||||
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
||||
validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
|
||||
validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
|
||||
validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
|
||||
validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
|
||||
validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
|
||||
validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
||||
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
|
||||
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
|
||||
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
|
||||
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
|
||||
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
|
||||
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
|
||||
return uint_crd_f_shape
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
|
||||
|
@ -516,17 +558,29 @@ class DihedralEnergy(PrimitiveWithInfer):
|
|||
ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
|
||||
cls_name = self.name
|
||||
M = self.dihedral_numbers
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
||||
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name)
|
||||
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name)
|
||||
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name)
|
||||
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name)
|
||||
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name)
|
||||
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
||||
validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
|
||||
validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
|
||||
validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
|
||||
validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
|
||||
validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
|
||||
validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
||||
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
|
||||
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
|
||||
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
|
||||
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
|
||||
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
|
||||
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
|
||||
return [M,]
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
|
||||
|
@ -595,17 +649,29 @@ class DihedralAtomEnergy(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = uint_crd_f_shape[0]
|
||||
M = self.dihedral_numbers
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
||||
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name)
|
||||
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name)
|
||||
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name)
|
||||
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name)
|
||||
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name)
|
||||
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
||||
validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
|
||||
validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
|
||||
validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
|
||||
validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
|
||||
validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
|
||||
validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
||||
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
|
||||
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
|
||||
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
|
||||
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
|
||||
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
|
||||
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
|
||||
return [N,]
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
|
||||
|
@ -673,17 +739,29 @@ class DihedralForceWithAtomEnergy(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = uint_crd_f_shape[0]
|
||||
M = self.dihedral_numbers
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
||||
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name)
|
||||
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name)
|
||||
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name)
|
||||
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name)
|
||||
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name)
|
||||
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
||||
validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
|
||||
validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
|
||||
validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
|
||||
validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
|
||||
validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
|
||||
validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
||||
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
|
||||
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
|
||||
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
|
||||
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
|
||||
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
|
||||
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
|
||||
return uint_crd_f_shape, [N,]
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
|
||||
|
@ -757,13 +835,21 @@ class AngleForce(PrimitiveWithInfer):
|
|||
angle_theta0_shape):
|
||||
cls_name = self.name
|
||||
M = self.angle_numbers
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
||||
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name)
|
||||
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
||||
validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
|
||||
validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
||||
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
|
||||
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
|
||||
return uint_crd_f_shape
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
|
||||
|
@ -825,13 +911,21 @@ class AngleEnergy(PrimitiveWithInfer):
|
|||
angle_theta0_shape):
|
||||
cls_name = self.name
|
||||
M = self.angle_numbers
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
||||
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name)
|
||||
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
||||
validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
|
||||
validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
||||
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
|
||||
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
|
||||
return [M,]
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
|
||||
|
@ -888,13 +982,21 @@ class AngleAtomEnergy(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = uint_crd_f_shape[0]
|
||||
M = self.angle_numbers
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
||||
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name)
|
||||
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
||||
validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
|
||||
validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
||||
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
|
||||
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
|
||||
return [N,]
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
|
||||
|
@ -951,13 +1053,21 @@ class AngleForceWithAtomEnergy(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = uint_crd_f_shape[0]
|
||||
M = self.angle_numbers
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
||||
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name)
|
||||
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
||||
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
||||
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
||||
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
||||
validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
|
||||
validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
||||
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
||||
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
||||
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
||||
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
|
||||
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
|
||||
return uint_crd_f_shape, [N,]
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
|
||||
|
@ -1029,16 +1139,21 @@ class Dihedral14LJForce(PrimitiveWithInfer):
|
|||
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
|
||||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.dihedral_14_numbers
|
||||
Q = LJ_type_A_shape[0]
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
||||
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
||||
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
||||
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
||||
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
|
||||
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
||||
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
||||
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
||||
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
|
||||
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
|
||||
return uint_crd_f_shape
|
||||
|
||||
|
@ -1112,16 +1227,21 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
|
|||
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
|
||||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.dihedral_14_numbers
|
||||
Q = LJ_type_A_shape[0]
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
||||
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
||||
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
||||
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
||||
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
|
||||
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
||||
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
||||
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
||||
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
|
||||
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
|
||||
return [self.dihedral_14_numbers,]
|
||||
|
||||
|
@ -1201,16 +1321,22 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
|
|||
N = self.atom_numbers
|
||||
M = self.dihedral_14_numbers
|
||||
Q = LJ_type_A_shape[0]
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
||||
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
||||
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
||||
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
|
||||
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name)
|
||||
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
||||
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
||||
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
||||
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
||||
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
|
||||
validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
|
||||
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
||||
validator.check_int(charge_shape[0], M, Rel.EQ, "charge_shape", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
|
||||
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
|
||||
return [self.atom_numbers, 3]
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
|
||||
|
@ -1286,18 +1412,23 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
|
|||
lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
|
||||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.dihedral_14_numbers
|
||||
Q = LJ_type_A_shape[0]
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
||||
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
||||
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
||||
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
|
||||
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name)
|
||||
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
||||
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
||||
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
||||
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
||||
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
|
||||
validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
|
||||
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
||||
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
|
||||
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
|
||||
return uint_crd_f_shape, charge_shape
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
|
||||
|
@ -1366,17 +1497,22 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
|
|||
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
|
||||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.dihedral_14_numbers
|
||||
Q = LJ_type_A_shape[0]
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
||||
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
||||
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
||||
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
|
||||
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
||||
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
||||
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
||||
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
||||
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
|
||||
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
||||
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
|
||||
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
|
||||
return LJtype_shape
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
|
||||
|
@ -1445,15 +1581,19 @@ class Dihedral14CFEnergy(PrimitiveWithInfer):
|
|||
cf_scale_factor_shape):
|
||||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.dihedral_14_numbers
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
||||
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
||||
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
||||
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
||||
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
||||
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
||||
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
||||
validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
||||
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
|
||||
return [self.dihedral_14_numbers,]
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
|
||||
|
@ -1516,15 +1656,19 @@ class Dihedral14CFAtomEnergy(PrimitiveWithInfer):
|
|||
cf_scale_factor_shape):
|
||||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
M = self.dihedral_14_numbers
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
||||
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
||||
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
||||
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name)
|
||||
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
||||
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
||||
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
||||
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
||||
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
||||
validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
||||
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
||||
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
||||
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
|
||||
return LJtype_shape
|
||||
|
||||
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
|
||||
|
@ -1672,10 +1816,14 @@ class PMEReciprocalForce(PrimitiveWithInfer):
|
|||
def infer_shape(self, boxlength_shape, uint_crd_shape, charge_shape):
|
||||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
||||
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
||||
validator.check_int(boxlength_shape[0], 3, Rel.EQ, "boxlength", cls_name)
|
||||
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
||||
validator.check_int(len(boxlength_shape), 1, Rel.EQ, "boxlength_dim", cls_name)
|
||||
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
||||
validator.check_int(boxlength_shape[0], 3, Rel.EQ, "boxlength_shape", cls_name)
|
||||
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
||||
return uint_crd_shape
|
||||
|
||||
def infer_dtype(self, boxlength_type, uint_crd_type, charge_type):
|
||||
|
@ -1693,6 +1841,7 @@ class PMEExcludedForce(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
atom_numbers(int32): the number of atoms, N.
|
||||
excluded_numbers(int32): the length of excluded list, E.
|
||||
beta(float32): the PME beta parameter, determined by the
|
||||
non-bond cutoff value and simulation precision tolerance.
|
||||
|
||||
|
@ -1716,27 +1865,36 @@ class PMEExcludedForce(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, atom_numbers, beta):
|
||||
def __init__(self, atom_numbers, excluded_numbers, beta):
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name)
|
||||
validator.check_value_type('beta', beta, (float), self.name)
|
||||
self.atom_numbers = atom_numbers
|
||||
self.excluded_numbers = excluded_numbers
|
||||
self.beta = beta
|
||||
self.init_prim_io_names(
|
||||
inputs=['uint_crd', 'sacler', 'charge', 'excluded_list_start', 'excluded_list', 'excluded_atom_numbers'],
|
||||
outputs=['force'])
|
||||
self.add_prim_attr('atom_numbers', self.atom_numbers)
|
||||
self.add_prim_attr('excluded_numbers', self.excluded_numbers)
|
||||
self.add_prim_attr('beta', self.beta)
|
||||
|
||||
def infer_shape(self, uint_crd_shape, sacler_shape, charge_shape, excluded_list_start_shape, excluded_list_shape,
|
||||
excluded_atom_numbers_shape):
|
||||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
||||
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
||||
validator.check_int(sacler_shape[0], 3, Rel.EQ, "sacler", cls_name)
|
||||
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(excluded_list_start_shape[0], N, Rel.EQ, "excluded_list_start", cls_name)
|
||||
validator.check_int(excluded_atom_numbers_shape[0], N, Rel.EQ, "excluded_atom_numbers", cls_name)
|
||||
validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
||||
validator.check_int(len(sacler_shape), 1, Rel.EQ, "sacler_dim", cls_name)
|
||||
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(excluded_list_start_shape), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
|
||||
validator.check_int(len(excluded_atom_numbers_shape), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
||||
validator.check_int(sacler_shape[0], 3, Rel.EQ, "sacler_shape", cls_name)
|
||||
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
||||
validator.check_int(excluded_list_start_shape[0], N, Rel.EQ, "excluded_list_start_shape", cls_name)
|
||||
validator.check_int(excluded_atom_numbers_shape[0], N, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
|
||||
return uint_crd_shape
|
||||
|
||||
def infer_dtype(self, uint_crd_type, sacler_type, charge_type, excluded_list_start_type, excluded_list_type,
|
||||
|
@ -1763,6 +1921,7 @@ class PMEEnergy(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
atom_numbers(int32): the number of atoms, N.
|
||||
excluded_numbers(int32): the length of excluded list, E.
|
||||
beta(float32): the PME beta parameter, determined by the
|
||||
non-bond cutoff value and simulation precision tolerance.
|
||||
fftx(int32): the number of points for Fourier transform in dimension X.
|
||||
|
@ -1795,13 +1954,15 @@ class PMEEnergy(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, atom_numbers, beta, fftx, ffty, fftz):
|
||||
def __init__(self, atom_numbers, excluded_numbers, beta, fftx, ffty, fftz):
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name)
|
||||
validator.check_value_type('beta', beta, (float), self.name)
|
||||
validator.check_value_type('fftx', fftx, (int), self.name)
|
||||
validator.check_value_type('ffty', ffty, (int), self.name)
|
||||
validator.check_value_type('fftz', fftz, (int), self.name)
|
||||
self.atom_numbers = atom_numbers
|
||||
self.excluded_numbers = excluded_numbers
|
||||
self.beta = beta
|
||||
self.fftx = fftx
|
||||
self.ffty = ffty
|
||||
|
@ -1811,6 +1972,7 @@ class PMEEnergy(PrimitiveWithInfer):
|
|||
'excluded_list', 'excluded_atom_numbers'],
|
||||
outputs=['reciprocal_ene', 'self_ene', 'direct_ene', 'correction_ene'])
|
||||
self.add_prim_attr('atom_numbers', self.atom_numbers)
|
||||
self.add_prim_attr('excluded_numbers', self.excluded_numbers)
|
||||
self.add_prim_attr('beta', self.beta)
|
||||
self.add_prim_attr('fftx', self.fftx)
|
||||
self.add_prim_attr('ffty', self.ffty)
|
||||
|
@ -1820,16 +1982,25 @@ class PMEEnergy(PrimitiveWithInfer):
|
|||
excluded_list, excluded_atom_numbers):
|
||||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
||||
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
||||
validator.check_int(box_length[0], 3, Rel.EQ, "box_length", cls_name)
|
||||
validator.check_int(charge[0], N, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers[0]", cls_name)
|
||||
validator.check_int(nl_serial[0], N, Rel.LE, "nl_serial[0]", cls_name)
|
||||
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name)
|
||||
validator.check_int(excluded_list_start[0], N, Rel.EQ, "excluded_list_start", cls_name)
|
||||
validator.check_int(excluded_atom_numbers[0], N, Rel.EQ, "excluded_atom_numbers", cls_name)
|
||||
validator.check_int(excluded_list[0], 0, Rel.GE, "excluded_list", cls_name)
|
||||
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
||||
validator.check_int(len(box_length), 1, Rel.EQ, "sacler_dim", cls_name)
|
||||
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
|
||||
validator.check_int(len(nl_serial), 2, Rel.LE, "nl_serial_dim", cls_name)
|
||||
validator.check_int(len(excluded_list_start), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
|
||||
validator.check_int(len(excluded_atom_numbers), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)
|
||||
validator.check_int(len(excluded_list), 1, Rel.GE, "excluded_list", cls_name)
|
||||
|
||||
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
||||
validator.check_int(box_length[0], 3, Rel.EQ, "box_length_shape", cls_name)
|
||||
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
|
||||
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape[0]", cls_name)
|
||||
validator.check_int(nl_serial[0], N, Rel.LE, "nl_serial_shape[0]", cls_name)
|
||||
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
|
||||
validator.check_int(excluded_list_start[0], N, Rel.EQ, "excluded_list_start_shape", cls_name)
|
||||
validator.check_int(excluded_atom_numbers[0], N, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
|
||||
validator.check_int(excluded_list[0], 0, Rel.GE, "excluded_list_shape", cls_name)
|
||||
return (1,), (1,), (1,), (1,)
|
||||
|
||||
def infer_dtype(self, box_length, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start,
|
||||
|
@ -1906,16 +2077,25 @@ class LJEnergy(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
Q = d_LJ_A[0]
|
||||
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
||||
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
||||
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
||||
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name)
|
||||
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name)
|
||||
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
||||
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name)
|
||||
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
||||
validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name)
|
||||
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
||||
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
|
||||
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
|
||||
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
||||
validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
||||
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
||||
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
||||
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name)
|
||||
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name)
|
||||
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
||||
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
|
||||
return charge
|
||||
|
||||
def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
|
||||
|
@ -1983,16 +2163,25 @@ class LJForce(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
Q = d_LJ_A[0]
|
||||
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
||||
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
||||
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
||||
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name)
|
||||
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name)
|
||||
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
||||
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name)
|
||||
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
||||
validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name)
|
||||
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
||||
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
|
||||
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
|
||||
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
||||
validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
||||
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
||||
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
||||
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name)
|
||||
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name)
|
||||
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
||||
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
|
||||
return uint_crd
|
||||
|
||||
def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
|
||||
|
@ -2058,16 +2247,25 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
|
|||
cls_name = self.name
|
||||
N = self.atom_numbers
|
||||
Q = d_LJ_A[0]
|
||||
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
||||
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
||||
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name)
|
||||
validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
||||
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name)
|
||||
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name)
|
||||
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
||||
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name)
|
||||
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
||||
validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name)
|
||||
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
|
||||
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
||||
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
|
||||
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
|
||||
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
||||
validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
|
||||
|
||||
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
||||
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
||||
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
||||
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
||||
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name)
|
||||
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name)
|
||||
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
|
||||
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
||||
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
|
||||
return uint_crd
|
||||
|
||||
def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
|
||||
|
|
|
@ -119,18 +119,19 @@ class Particle_Mesh_Ewald(nn.Cell):
|
|||
tempi += 4
|
||||
|
||||
def PME_Energy(self, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof, excluded_list_start,
|
||||
excluded_list, excluded_numbers):
|
||||
excluded_list, excluded_numbers, excluded_atom_numbers):
|
||||
"""PME_Energy"""
|
||||
self.pmee = P.PMEEnergy(self.atom_numbers, self.beta, self.fftx, self.ffty, self.fftz)
|
||||
self.pmee = P.PMEEnergy(self.atom_numbers, excluded_atom_numbers, self.beta, self.fftx, self.ffty, self.fftz)
|
||||
self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy = \
|
||||
self.pmee(self.box_length, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof,
|
||||
excluded_list_start, excluded_list, excluded_numbers)
|
||||
return self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy
|
||||
|
||||
def PME_Excluded_Force(self, uint_crd, scaler, charge, excluded_list_start, excluded_list,
|
||||
excluded_numbers):
|
||||
excluded_numbers, excluded_atom_numbers):
|
||||
"""PME Excluded Force"""
|
||||
self.pmeef = P.PMEExcludedForce(atom_numbers=self.atom_numbers, beta=self.beta)
|
||||
self.pmeef = P.PMEExcludedForce(atom_numbers=self.atom_numbers, excluded_numbers=excluded_atom_numbers,
|
||||
beta=self.beta)
|
||||
self.frc = self.pmeef(uint_crd, scaler, charge, excluded_list_start, excluded_list, excluded_numbers)
|
||||
return self.frc
|
||||
|
||||
|
|
|
@ -124,7 +124,7 @@ class Simulation(nn.Cell):
|
|||
pme_excluded_frc = pme_method.PME_Excluded_Force(
|
||||
md_info.uint_crd, md_info.uint_dr_to_dr_cof, md_info.charge,
|
||||
nb_info.excluded_list_start, nb_info.excluded_list,
|
||||
nb_info.excluded_numbers)
|
||||
nb_info.excluded_numbers, nb_info.excluded_atom_numbers)
|
||||
frc_t += pme_excluded_frc.asnumpy()
|
||||
|
||||
pme_reciprocal_frc = pme_method.PME_Reciprocal_Force(md_info.uint_crd, md_info.charge)
|
||||
|
@ -149,7 +149,7 @@ class Simulation(nn.Cell):
|
|||
_ = self.pme_method.PME_Energy(
|
||||
self.md_info.uint_crd, self.md_info.charge, self.nb_info.nl_atom_numbers, self.nb_info.nl_atom_serial,
|
||||
self.md_info.uint_dr_to_dr_cof, self.nb_info.excluded_list_start, self.nb_info.excluded_list,
|
||||
self.nb_info.excluded_numbers)
|
||||
self.nb_info.excluded_numbers, self.nb_info.excluded_atom_numbers)
|
||||
_ = self.pme_method.Energy_Device_To_Host()
|
||||
|
||||
def Main_After_Calculate_Energy(self):
|
||||
|
|
Loading…
Reference in New Issue