!14144 update sponge

From: @zhangxinfeng3
Reviewed-by: @wang_zi_dong,@ljl0711
Signed-off-by: @ljl0711
This commit is contained in:
mindspore-ci-bot 2021-03-27 15:59:13 +08:00 committed by Gitee
commit 0a0dc05d51
5 changed files with 431 additions and 230 deletions

View File

@ -35,6 +35,7 @@ class PMEEnergyGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers")); atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
excluded_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "excluded_numbers"));
beta = static_cast<float>(GetAttr<float_t>(kernel_node, "beta")); beta = static_cast<float>(GetAttr<float_t>(kernel_node, "beta"));
fftx = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftx")); fftx = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftx"));
ffty = static_cast<int>(GetAttr<int64_t>(kernel_node, "ffty")); ffty = static_cast<int>(GetAttr<int64_t>(kernel_node, "ffty"));
@ -90,7 +91,7 @@ class PMEEnergyGpuKernel : public GpuKernel {
input_size_list_.push_back(atom_numbers * sizeof(VECTOR)); input_size_list_.push_back(atom_numbers * sizeof(VECTOR));
input_size_list_.push_back(atom_numbers * sizeof(T1)); input_size_list_.push_back(atom_numbers * sizeof(T1));
input_size_list_.push_back(max_nl_numbers * sizeof(T1)); input_size_list_.push_back(max_nl_numbers * sizeof(T1));
input_size_list_.push_back(atom_numbers * sizeof(VECTOR)); input_size_list_.push_back(sizeof(VECTOR));
input_size_list_.push_back(atom_numbers * sizeof(T1)); input_size_list_.push_back(atom_numbers * sizeof(T1));
input_size_list_.push_back(excluded_numbers * sizeof(T1)); input_size_list_.push_back(excluded_numbers * sizeof(T1));
@ -118,7 +119,7 @@ class PMEEnergyGpuKernel : public GpuKernel {
std::vector<size_t> output_size_list_; std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_; std::vector<size_t> workspace_size_list_;
int atom_numbers; int atom_numbers;
int excluded_numbers = 2719; int excluded_numbers;
int max_nl_numbers = 800; int max_nl_numbers = 800;
int fftx; int fftx;
int ffty; int ffty;

View File

@ -35,6 +35,7 @@ class PMEExcludedForceGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node; kernel_node_ = kernel_node;
atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers")); atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
excluded_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "excluded_numbers"));
beta = static_cast<float>(GetAttr<float_t>(kernel_node, "beta")); beta = static_cast<float>(GetAttr<float_t>(kernel_node, "beta"));
InitSizeLists(); InitSizeLists();
return true; return true;
@ -62,7 +63,7 @@ class PMEExcludedForceGpuKernel : public GpuKernel {
protected: protected:
void InitSizeLists() override { void InitSizeLists() override {
input_size_list_.push_back(atom_numbers * sizeof(UNSIGNED_INT_VECTOR)); input_size_list_.push_back(atom_numbers * sizeof(UNSIGNED_INT_VECTOR));
input_size_list_.push_back(atom_numbers * sizeof(VECTOR)); input_size_list_.push_back(sizeof(VECTOR));
input_size_list_.push_back(atom_numbers * sizeof(T)); input_size_list_.push_back(atom_numbers * sizeof(T));
input_size_list_.push_back(atom_numbers * sizeof(T1)); input_size_list_.push_back(atom_numbers * sizeof(T1));
input_size_list_.push_back(excluded_numbers * sizeof(T1)); input_size_list_.push_back(excluded_numbers * sizeof(T1));
@ -77,7 +78,7 @@ class PMEExcludedForceGpuKernel : public GpuKernel {
std::vector<size_t> output_size_list_; std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_; std::vector<size_t> workspace_size_list_;
int atom_numbers; int atom_numbers;
int excluded_numbers = 2719; int excluded_numbers;
float beta; float beta;
struct VECTOR { struct VECTOR {
float x; float x;

View File

@ -69,13 +69,19 @@ class BondForce(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.bond_numbers M = self.bond_numbers
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name) validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name) validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
return uint_crd_f_shape return uint_crd_f_shape
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type): def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
@ -136,13 +142,19 @@ class BondEnergy(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.bond_numbers M = self.bond_numbers
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name) validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name) validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
return bond_k_shape return bond_k_shape
@ -198,13 +210,19 @@ class BondAtomEnergy(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.bond_numbers M = self.bond_numbers
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name) validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name) validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
return [N,] return [N,]
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type): def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
@ -259,13 +277,19 @@ class BondForceWithAtomEnergy(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.bond_numbers M = self.bond_numbers
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name) validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name) validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
return uint_crd_f_shape, [N,] return uint_crd_f_shape, [N,]
@ -333,13 +357,19 @@ class BondForceWithAtomVirial(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.bond_numbers M = self.bond_numbers
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name) validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name) validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
return uint_crd_f_shape, [N,] return uint_crd_f_shape, [N,]
@ -436,17 +466,29 @@ class DihedralForce(PrimitiveWithInfer):
ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape): ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
cls_name = self.name cls_name = self.name
M = self.dihedral_numbers M = self.dihedral_numbers
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name) validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name) validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name) validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name) validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name) validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name) validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
return uint_crd_f_shape return uint_crd_f_shape
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
@ -516,17 +558,29 @@ class DihedralEnergy(PrimitiveWithInfer):
ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape): ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
cls_name = self.name cls_name = self.name
M = self.dihedral_numbers M = self.dihedral_numbers
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name) validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name) validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name) validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name) validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name) validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name) validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
return [M,] return [M,]
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
@ -595,17 +649,29 @@ class DihedralAtomEnergy(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = uint_crd_f_shape[0] N = uint_crd_f_shape[0]
M = self.dihedral_numbers M = self.dihedral_numbers
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name) validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name) validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name) validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name) validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name) validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name) validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
return [N,] return [N,]
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
@ -673,17 +739,29 @@ class DihedralForceWithAtomEnergy(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = uint_crd_f_shape[0] N = uint_crd_f_shape[0]
M = self.dihedral_numbers M = self.dihedral_numbers
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name) validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name) validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name) validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name) validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name) validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name) validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
return uint_crd_f_shape, [N,] return uint_crd_f_shape, [N,]
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
@ -757,13 +835,21 @@ class AngleForce(PrimitiveWithInfer):
angle_theta0_shape): angle_theta0_shape):
cls_name = self.name cls_name = self.name
M = self.angle_numbers M = self.angle_numbers
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name) validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name) validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
return uint_crd_f_shape return uint_crd_f_shape
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
@ -825,13 +911,21 @@ class AngleEnergy(PrimitiveWithInfer):
angle_theta0_shape): angle_theta0_shape):
cls_name = self.name cls_name = self.name
M = self.angle_numbers M = self.angle_numbers
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name) validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name) validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
return [M,] return [M,]
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
@ -888,13 +982,21 @@ class AngleAtomEnergy(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = uint_crd_f_shape[0] N = uint_crd_f_shape[0]
M = self.angle_numbers M = self.angle_numbers
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name) validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name) validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
return [N,] return [N,]
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
@ -951,13 +1053,21 @@ class AngleForceWithAtomEnergy(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = uint_crd_f_shape[0] N = uint_crd_f_shape[0]
M = self.angle_numbers M = self.angle_numbers
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name) validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name) validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
return uint_crd_f_shape, [N,] return uint_crd_f_shape, [N,]
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
@ -1029,16 +1139,21 @@ class Dihedral14LJForce(PrimitiveWithInfer):
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.dihedral_14_numbers
Q = LJ_type_A_shape[0] Q = LJ_type_A_shape[0]
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype", cls_name)
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
return uint_crd_f_shape return uint_crd_f_shape
@ -1112,16 +1227,21 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.dihedral_14_numbers
Q = LJ_type_A_shape[0] Q = LJ_type_A_shape[0]
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype", cls_name)
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
return [self.dihedral_14_numbers,] return [self.dihedral_14_numbers,]
@ -1201,16 +1321,22 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
N = self.atom_numbers N = self.atom_numbers
M = self.dihedral_14_numbers M = self.dihedral_14_numbers
Q = LJ_type_A_shape[0] Q = LJ_type_A_shape[0]
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name) validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name) validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name) validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge_shape[0], M, Rel.EQ, "charge_shape", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
return [self.atom_numbers, 3] return [self.atom_numbers, 3]
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@ -1286,18 +1412,23 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.dihedral_14_numbers
Q = LJ_type_A_shape[0] Q = LJ_type_A_shape[0]
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name) validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name) validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name) validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
return uint_crd_f_shape, charge_shape return uint_crd_f_shape, charge_shape
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@ -1366,17 +1497,22 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.dihedral_14_numbers
Q = LJ_type_A_shape[0] Q = LJ_type_A_shape[0]
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name) validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name) validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
return LJtype_shape return LJtype_shape
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@ -1445,15 +1581,19 @@ class Dihedral14CFEnergy(PrimitiveWithInfer):
cf_scale_factor_shape): cf_scale_factor_shape):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.dihedral_14_numbers validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name) validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
return [self.dihedral_14_numbers,] return [self.dihedral_14_numbers,]
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@ -1516,15 +1656,19 @@ class Dihedral14CFAtomEnergy(PrimitiveWithInfer):
cf_scale_factor_shape): cf_scale_factor_shape):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
M = self.dihedral_14_numbers validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name) validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
return LJtype_shape return LJtype_shape
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
@ -1672,10 +1816,14 @@ class PMEReciprocalForce(PrimitiveWithInfer):
def infer_shape(self, boxlength_shape, uint_crd_shape, charge_shape): def infer_shape(self, boxlength_shape, uint_crd_shape, charge_shape):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd[0]", cls_name) validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd[1]", cls_name) validator.check_int(len(boxlength_shape), 1, Rel.EQ, "boxlength_dim", cls_name)
validator.check_int(boxlength_shape[0], 3, Rel.EQ, "boxlength", cls_name) validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
validator.check_int(boxlength_shape[0], 3, Rel.EQ, "boxlength_shape", cls_name)
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
return uint_crd_shape return uint_crd_shape
def infer_dtype(self, boxlength_type, uint_crd_type, charge_type): def infer_dtype(self, boxlength_type, uint_crd_type, charge_type):
@ -1693,6 +1841,7 @@ class PMEExcludedForce(PrimitiveWithInfer):
Args: Args:
atom_numbers(int32): the number of atoms, N. atom_numbers(int32): the number of atoms, N.
excluded_numbers(int32): the length of excluded list, E.
beta(float32): the PME beta parameter, determined by the beta(float32): the PME beta parameter, determined by the
non-bond cutoff value and simulation precision tolerance. non-bond cutoff value and simulation precision tolerance.
@ -1716,27 +1865,36 @@ class PMEExcludedForce(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, atom_numbers, beta): def __init__(self, atom_numbers, excluded_numbers, beta):
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name) validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name)
validator.check_value_type('beta', beta, (float), self.name) validator.check_value_type('beta', beta, (float), self.name)
self.atom_numbers = atom_numbers self.atom_numbers = atom_numbers
self.excluded_numbers = excluded_numbers
self.beta = beta self.beta = beta
self.init_prim_io_names( self.init_prim_io_names(
inputs=['uint_crd', 'sacler', 'charge', 'excluded_list_start', 'excluded_list', 'excluded_atom_numbers'], inputs=['uint_crd', 'sacler', 'charge', 'excluded_list_start', 'excluded_list', 'excluded_atom_numbers'],
outputs=['force']) outputs=['force'])
self.add_prim_attr('atom_numbers', self.atom_numbers) self.add_prim_attr('atom_numbers', self.atom_numbers)
self.add_prim_attr('excluded_numbers', self.excluded_numbers)
self.add_prim_attr('beta', self.beta) self.add_prim_attr('beta', self.beta)
def infer_shape(self, uint_crd_shape, sacler_shape, charge_shape, excluded_list_start_shape, excluded_list_shape, def infer_shape(self, uint_crd_shape, sacler_shape, charge_shape, excluded_list_start_shape, excluded_list_shape,
excluded_atom_numbers_shape): excluded_atom_numbers_shape):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd[0]", cls_name) validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd[1]", cls_name) validator.check_int(len(sacler_shape), 1, Rel.EQ, "sacler_dim", cls_name)
validator.check_int(sacler_shape[0], 3, Rel.EQ, "sacler", cls_name) validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name) validator.check_int(len(excluded_list_start_shape), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
validator.check_int(excluded_list_start_shape[0], N, Rel.EQ, "excluded_list_start", cls_name) validator.check_int(len(excluded_atom_numbers_shape), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)
validator.check_int(excluded_atom_numbers_shape[0], N, Rel.EQ, "excluded_atom_numbers", cls_name)
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
validator.check_int(sacler_shape[0], 3, Rel.EQ, "sacler_shape", cls_name)
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
validator.check_int(excluded_list_start_shape[0], N, Rel.EQ, "excluded_list_start_shape", cls_name)
validator.check_int(excluded_atom_numbers_shape[0], N, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
return uint_crd_shape return uint_crd_shape
def infer_dtype(self, uint_crd_type, sacler_type, charge_type, excluded_list_start_type, excluded_list_type, def infer_dtype(self, uint_crd_type, sacler_type, charge_type, excluded_list_start_type, excluded_list_type,
@ -1763,6 +1921,7 @@ class PMEEnergy(PrimitiveWithInfer):
Args: Args:
atom_numbers(int32): the number of atoms, N. atom_numbers(int32): the number of atoms, N.
excluded_numbers(int32): the length of excluded list, E.
beta(float32): the PME beta parameter, determined by the beta(float32): the PME beta parameter, determined by the
non-bond cutoff value and simulation precision tolerance. non-bond cutoff value and simulation precision tolerance.
fftx(int32): the number of points for Fourier transform in dimension X. fftx(int32): the number of points for Fourier transform in dimension X.
@ -1795,13 +1954,15 @@ class PMEEnergy(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, atom_numbers, beta, fftx, ffty, fftz): def __init__(self, atom_numbers, excluded_numbers, beta, fftx, ffty, fftz):
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name) validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name)
validator.check_value_type('beta', beta, (float), self.name) validator.check_value_type('beta', beta, (float), self.name)
validator.check_value_type('fftx', fftx, (int), self.name) validator.check_value_type('fftx', fftx, (int), self.name)
validator.check_value_type('ffty', ffty, (int), self.name) validator.check_value_type('ffty', ffty, (int), self.name)
validator.check_value_type('fftz', fftz, (int), self.name) validator.check_value_type('fftz', fftz, (int), self.name)
self.atom_numbers = atom_numbers self.atom_numbers = atom_numbers
self.excluded_numbers = excluded_numbers
self.beta = beta self.beta = beta
self.fftx = fftx self.fftx = fftx
self.ffty = ffty self.ffty = ffty
@ -1811,6 +1972,7 @@ class PMEEnergy(PrimitiveWithInfer):
'excluded_list', 'excluded_atom_numbers'], 'excluded_list', 'excluded_atom_numbers'],
outputs=['reciprocal_ene', 'self_ene', 'direct_ene', 'correction_ene']) outputs=['reciprocal_ene', 'self_ene', 'direct_ene', 'correction_ene'])
self.add_prim_attr('atom_numbers', self.atom_numbers) self.add_prim_attr('atom_numbers', self.atom_numbers)
self.add_prim_attr('excluded_numbers', self.excluded_numbers)
self.add_prim_attr('beta', self.beta) self.add_prim_attr('beta', self.beta)
self.add_prim_attr('fftx', self.fftx) self.add_prim_attr('fftx', self.fftx)
self.add_prim_attr('ffty', self.ffty) self.add_prim_attr('ffty', self.ffty)
@ -1820,16 +1982,25 @@ class PMEEnergy(PrimitiveWithInfer):
excluded_list, excluded_atom_numbers): excluded_list, excluded_atom_numbers):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name) validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name) validator.check_int(len(box_length), 1, Rel.EQ, "sacler_dim", cls_name)
validator.check_int(box_length[0], 3, Rel.EQ, "box_length", cls_name) validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(charge[0], N, Rel.EQ, "charge", cls_name) validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers[0]", cls_name) validator.check_int(len(nl_serial), 2, Rel.LE, "nl_serial_dim", cls_name)
validator.check_int(nl_serial[0], N, Rel.LE, "nl_serial[0]", cls_name) validator.check_int(len(excluded_list_start), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name) validator.check_int(len(excluded_atom_numbers), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)
validator.check_int(excluded_list_start[0], N, Rel.EQ, "excluded_list_start", cls_name) validator.check_int(len(excluded_list), 1, Rel.GE, "excluded_list", cls_name)
validator.check_int(excluded_atom_numbers[0], N, Rel.EQ, "excluded_atom_numbers", cls_name)
validator.check_int(excluded_list[0], 0, Rel.GE, "excluded_list", cls_name) validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
validator.check_int(box_length[0], 3, Rel.EQ, "box_length_shape", cls_name)
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape[0]", cls_name)
validator.check_int(nl_serial[0], N, Rel.LE, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
validator.check_int(excluded_list_start[0], N, Rel.EQ, "excluded_list_start_shape", cls_name)
validator.check_int(excluded_atom_numbers[0], N, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
validator.check_int(excluded_list[0], 0, Rel.GE, "excluded_list_shape", cls_name)
return (1,), (1,), (1,), (1,) return (1,), (1,), (1,), (1,)
def infer_dtype(self, box_length, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start, def infer_dtype(self, box_length, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start,
@ -1906,16 +2077,25 @@ class LJEnergy(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
Q = d_LJ_A[0] Q = d_LJ_A[0]
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name) validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name) validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name) validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name) validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name) validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name) validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name) validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name) validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name) validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name)
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
return charge return charge
def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B): def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
@ -1983,16 +2163,25 @@ class LJForce(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
Q = d_LJ_A[0] Q = d_LJ_A[0]
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name) validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name) validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name) validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name) validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name) validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name) validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name) validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name) validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name) validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name)
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
return uint_crd return uint_crd
def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B): def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
@ -2058,16 +2247,25 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
N = self.atom_numbers N = self.atom_numbers
Q = d_LJ_A[0] Q = d_LJ_A[0]
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name) validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name) validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name) validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name) validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name) validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name) validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name) validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name) validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name) validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name)
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
return uint_crd return uint_crd
def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B): def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):

View File

@ -119,18 +119,19 @@ class Particle_Mesh_Ewald(nn.Cell):
tempi += 4 tempi += 4
def PME_Energy(self, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof, excluded_list_start, def PME_Energy(self, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof, excluded_list_start,
excluded_list, excluded_numbers): excluded_list, excluded_numbers, excluded_atom_numbers):
"""PME_Energy""" """PME_Energy"""
self.pmee = P.PMEEnergy(self.atom_numbers, self.beta, self.fftx, self.ffty, self.fftz) self.pmee = P.PMEEnergy(self.atom_numbers, excluded_atom_numbers, self.beta, self.fftx, self.ffty, self.fftz)
self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy = \ self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy = \
self.pmee(self.box_length, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof, self.pmee(self.box_length, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof,
excluded_list_start, excluded_list, excluded_numbers) excluded_list_start, excluded_list, excluded_numbers)
return self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy return self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy
def PME_Excluded_Force(self, uint_crd, scaler, charge, excluded_list_start, excluded_list, def PME_Excluded_Force(self, uint_crd, scaler, charge, excluded_list_start, excluded_list,
excluded_numbers): excluded_numbers, excluded_atom_numbers):
"""PME Excluded Force""" """PME Excluded Force"""
self.pmeef = P.PMEExcludedForce(atom_numbers=self.atom_numbers, beta=self.beta) self.pmeef = P.PMEExcludedForce(atom_numbers=self.atom_numbers, excluded_numbers=excluded_atom_numbers,
beta=self.beta)
self.frc = self.pmeef(uint_crd, scaler, charge, excluded_list_start, excluded_list, excluded_numbers) self.frc = self.pmeef(uint_crd, scaler, charge, excluded_list_start, excluded_list, excluded_numbers)
return self.frc return self.frc

View File

@ -124,7 +124,7 @@ class Simulation(nn.Cell):
pme_excluded_frc = pme_method.PME_Excluded_Force( pme_excluded_frc = pme_method.PME_Excluded_Force(
md_info.uint_crd, md_info.uint_dr_to_dr_cof, md_info.charge, md_info.uint_crd, md_info.uint_dr_to_dr_cof, md_info.charge,
nb_info.excluded_list_start, nb_info.excluded_list, nb_info.excluded_list_start, nb_info.excluded_list,
nb_info.excluded_numbers) nb_info.excluded_numbers, nb_info.excluded_atom_numbers)
frc_t += pme_excluded_frc.asnumpy() frc_t += pme_excluded_frc.asnumpy()
pme_reciprocal_frc = pme_method.PME_Reciprocal_Force(md_info.uint_crd, md_info.charge) pme_reciprocal_frc = pme_method.PME_Reciprocal_Force(md_info.uint_crd, md_info.charge)
@ -149,7 +149,7 @@ class Simulation(nn.Cell):
_ = self.pme_method.PME_Energy( _ = self.pme_method.PME_Energy(
self.md_info.uint_crd, self.md_info.charge, self.nb_info.nl_atom_numbers, self.nb_info.nl_atom_serial, self.md_info.uint_crd, self.md_info.charge, self.nb_info.nl_atom_numbers, self.nb_info.nl_atom_serial,
self.md_info.uint_dr_to_dr_cof, self.nb_info.excluded_list_start, self.nb_info.excluded_list, self.md_info.uint_dr_to_dr_cof, self.nb_info.excluded_list_start, self.nb_info.excluded_list,
self.nb_info.excluded_numbers) self.nb_info.excluded_numbers, self.nb_info.excluded_atom_numbers)
_ = self.pme_method.Energy_Device_To_Host() _ = self.pme_method.Energy_Device_To_Host()
def Main_After_Calculate_Energy(self): def Main_After_Calculate_Energy(self):