From 4e6b73937adafb9020ace12e1c31f78e716a1164 Mon Sep 17 00:00:00 2001 From: zhangxinfeng3 Date: Fri, 26 Mar 2021 10:13:42 +0800 Subject: [PATCH] update spongeQm --- .../gpu/sponge/pme/pme_energy_kernel.h | 5 +- .../sponge/pme/pme_excluded_force_kernel.h | 5 +- mindspore/ops/operations/sponge_ops.py | 638 ++++++++++++------ .../hpc/sponge/src/particle_mesh_ewald.py | 9 +- .../hpc/sponge/src/simulation_initial.py | 4 +- 5 files changed, 431 insertions(+), 230 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/pme/pme_energy_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/pme/pme_energy_kernel.h index 51b49c41e24..4f151040ba3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/pme/pme_energy_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/pme/pme_energy_kernel.h @@ -35,6 +35,7 @@ class PMEEnergyGpuKernel : public GpuKernel { bool Init(const CNodePtr &kernel_node) override { kernel_node_ = kernel_node; atom_numbers = static_cast(GetAttr(kernel_node, "atom_numbers")); + excluded_numbers = static_cast(GetAttr(kernel_node, "excluded_numbers")); beta = static_cast(GetAttr(kernel_node, "beta")); fftx = static_cast(GetAttr(kernel_node, "fftx")); ffty = static_cast(GetAttr(kernel_node, "ffty")); @@ -90,7 +91,7 @@ class PMEEnergyGpuKernel : public GpuKernel { input_size_list_.push_back(atom_numbers * sizeof(VECTOR)); input_size_list_.push_back(atom_numbers * sizeof(T1)); input_size_list_.push_back(max_nl_numbers * sizeof(T1)); - input_size_list_.push_back(atom_numbers * sizeof(VECTOR)); + input_size_list_.push_back(sizeof(VECTOR)); input_size_list_.push_back(atom_numbers * sizeof(T1)); input_size_list_.push_back(excluded_numbers * sizeof(T1)); @@ -118,7 +119,7 @@ class PMEEnergyGpuKernel : public GpuKernel { std::vector output_size_list_; std::vector workspace_size_list_; int atom_numbers; - int excluded_numbers = 2719; + int excluded_numbers; int max_nl_numbers = 800; int fftx; int ffty; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/pme/pme_excluded_force_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/pme/pme_excluded_force_kernel.h index f8b4486210b..b7becdeaaa6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/pme/pme_excluded_force_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/pme/pme_excluded_force_kernel.h @@ -35,6 +35,7 @@ class PMEExcludedForceGpuKernel : public GpuKernel { bool Init(const CNodePtr &kernel_node) override { kernel_node_ = kernel_node; atom_numbers = static_cast(GetAttr(kernel_node, "atom_numbers")); + excluded_numbers = static_cast(GetAttr(kernel_node, "excluded_numbers")); beta = static_cast(GetAttr(kernel_node, "beta")); InitSizeLists(); return true; @@ -62,7 +63,7 @@ class PMEExcludedForceGpuKernel : public GpuKernel { protected: void InitSizeLists() override { input_size_list_.push_back(atom_numbers * sizeof(UNSIGNED_INT_VECTOR)); - input_size_list_.push_back(atom_numbers * sizeof(VECTOR)); + input_size_list_.push_back(sizeof(VECTOR)); input_size_list_.push_back(atom_numbers * sizeof(T)); input_size_list_.push_back(atom_numbers * sizeof(T1)); input_size_list_.push_back(excluded_numbers * sizeof(T1)); @@ -77,7 +78,7 @@ class PMEExcludedForceGpuKernel : public GpuKernel { std::vector output_size_list_; std::vector workspace_size_list_; int atom_numbers; - int excluded_numbers = 2719; + int excluded_numbers; float beta; struct VECTOR { float x; diff --git a/mindspore/ops/operations/sponge_ops.py b/mindspore/ops/operations/sponge_ops.py index 9ce217ff4ef..20fdf561952 100644 --- a/mindspore/ops/operations/sponge_ops.py +++ b/mindspore/ops/operations/sponge_ops.py @@ -69,13 +69,19 @@ class BondForce(PrimitiveWithInfer): cls_name = self.name N = self.atom_numbers M = self.bond_numbers - validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name) - validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name) + validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name) + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name) + validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name) return uint_crd_f_shape def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type): @@ -136,13 +142,19 @@ class BondEnergy(PrimitiveWithInfer): cls_name = self.name N = self.atom_numbers M = self.bond_numbers - validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name) - validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name) + validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name) + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name) + validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name) return bond_k_shape @@ -198,13 +210,19 @@ class BondAtomEnergy(PrimitiveWithInfer): cls_name = self.name N = self.atom_numbers M = self.bond_numbers - validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name) - validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name) + validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name) + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name) + validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name) return [N,] def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type): @@ -259,13 +277,19 @@ class BondForceWithAtomEnergy(PrimitiveWithInfer): cls_name = self.name N = self.atom_numbers M = self.bond_numbers - validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name) - validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name) + validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name) + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name) + validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name) return uint_crd_f_shape, [N,] @@ -333,13 +357,19 @@ class BondForceWithAtomVirial(PrimitiveWithInfer): cls_name = self.name N = self.atom_numbers M = self.bond_numbers - validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name) - validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name) + validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name) + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name) + validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name) return uint_crd_f_shape, [N,] @@ -436,17 +466,29 @@ class DihedralForce(PrimitiveWithInfer): ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape): cls_name = self.name M = self.dihedral_numbers - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) - validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name) - validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name) - validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name) - validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name) - validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name) - validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name) + validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name) + validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name) + validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name) + validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name) + validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name) + validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name) + + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) + validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name) + validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name) + validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name) + validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name) + validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name) + validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name) return uint_crd_f_shape def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, @@ -516,17 +558,29 @@ class DihedralEnergy(PrimitiveWithInfer): ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape): cls_name = self.name M = self.dihedral_numbers - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) - validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name) - validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name) - validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name) - validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name) - validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name) - validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name) + validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name) + validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name) + validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name) + validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name) + validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name) + validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name) + + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) + validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name) + validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name) + validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name) + validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name) + validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name) + validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name) return [M,] def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, @@ -595,17 +649,29 @@ class DihedralAtomEnergy(PrimitiveWithInfer): cls_name = self.name N = uint_crd_f_shape[0] M = self.dihedral_numbers - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) - validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name) - validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name) - validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name) - validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name) - validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name) - validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name) + validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name) + validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name) + validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name) + validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name) + validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name) + validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name) + + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) + validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name) + validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name) + validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name) + validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name) + validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name) + validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name) return [N,] def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, @@ -673,17 +739,29 @@ class DihedralForceWithAtomEnergy(PrimitiveWithInfer): cls_name = self.name N = uint_crd_f_shape[0] M = self.dihedral_numbers - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) - validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name) - validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name) - validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name) - validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name) - validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name) - validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name) + validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name) + validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name) + validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name) + validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name) + validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name) + validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name) + + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) + validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name) + validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name) + validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name) + validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name) + validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name) + validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name) return uint_crd_f_shape, [N,] def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, @@ -757,13 +835,21 @@ class AngleForce(PrimitiveWithInfer): angle_theta0_shape): cls_name = self.name M = self.angle_numbers - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) - validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name) - validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name) + validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name) + validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name) + + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) + validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name) + validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name) return uint_crd_f_shape def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, @@ -825,13 +911,21 @@ class AngleEnergy(PrimitiveWithInfer): angle_theta0_shape): cls_name = self.name M = self.angle_numbers - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) - validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name) - validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name) + validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name) + validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name) + + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) + validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name) + validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name) return [M,] def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, @@ -888,13 +982,21 @@ class AngleAtomEnergy(PrimitiveWithInfer): cls_name = self.name N = uint_crd_f_shape[0] M = self.angle_numbers - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) - validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name) - validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name) + validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name) + validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name) + + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) + validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name) + validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name) return [N,] def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, @@ -951,13 +1053,21 @@ class AngleForceWithAtomEnergy(PrimitiveWithInfer): cls_name = self.name N = uint_crd_f_shape[0] M = self.angle_numbers - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name) - validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name) - validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name) - validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name) - validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name) - validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name) + validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name) + validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name) + validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name) + validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name) + validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name) + + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) + validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) + validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) + validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) + validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name) + validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name) return uint_crd_f_shape, [N,] def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, @@ -1029,16 +1139,21 @@ class Dihedral14LJForce(PrimitiveWithInfer): lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): cls_name = self.name N = self.atom_numbers - M = self.dihedral_14_numbers Q = LJ_type_A_shape[0] + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) + validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name) + validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name) + validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name) + validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name) + validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name) + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) - validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) + validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype", cls_name) + validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name) validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) - validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) - validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name) - validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name) validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) return uint_crd_f_shape @@ -1112,16 +1227,21 @@ class Dihedral14LJEnergy(PrimitiveWithInfer): lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): cls_name = self.name N = self.atom_numbers - M = self.dihedral_14_numbers Q = LJ_type_A_shape[0] + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) + validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name) + validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name) + validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name) + validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name) + validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name) + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) - validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) + validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype", cls_name) + validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name) validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) - validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) - validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name) - validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name) validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) return [self.dihedral_14_numbers,] @@ -1201,16 +1321,22 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer): N = self.atom_numbers M = self.dihedral_14_numbers Q = LJ_type_A_shape[0] - validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) - validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) - validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) - validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) - validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name) - validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name) - validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name) - validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) + validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name) + validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name) + validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name) + validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name) + validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name) + validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name) + + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) + validator.check_int(charge_shape[0], M, Rel.EQ, "charge_shape", cls_name) + validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) + validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name) return [self.atom_numbers, 3] def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, @@ -1286,18 +1412,23 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer): lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): cls_name = self.name N = self.atom_numbers - M = self.dihedral_14_numbers Q = LJ_type_A_shape[0] - validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) - validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) - validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) - validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) - validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name) - validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name) - validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name) - validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) + validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name) + validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name) + validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name) + validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name) + validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name) + validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name) + + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) + validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) + validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) + validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name) return uint_crd_f_shape, charge_shape def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, @@ -1366,17 +1497,22 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer): lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): cls_name = self.name N = self.atom_numbers - M = self.dihedral_14_numbers Q = LJ_type_A_shape[0] - validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) - validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) - validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) - validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) - validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name) - validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name) - validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) + validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name) + validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name) + validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name) + validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name) + validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name) + + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) + validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) + validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) + validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name) return LJtype_shape def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, @@ -1445,15 +1581,19 @@ class Dihedral14CFEnergy(PrimitiveWithInfer): cf_scale_factor_shape): cls_name = self.name N = self.atom_numbers - M = self.dihedral_14_numbers - validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) - validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) - validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) - validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) - validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name) - validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) + validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name) + validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name) + validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name) + validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name) + + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) + validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) + validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) return [self.dihedral_14_numbers,] def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, @@ -1516,15 +1656,19 @@ class Dihedral14CFAtomEnergy(PrimitiveWithInfer): cf_scale_factor_shape): cls_name = self.name N = self.atom_numbers - M = self.dihedral_14_numbers - validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name) - validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name) - validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name) - validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name) - validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) - validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name) - validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name) - validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name) + validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) + validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) + validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name) + validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name) + validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name) + validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name) + + validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) + validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) + validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) + validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) + validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) return LJtype_shape def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, @@ -1672,10 +1816,14 @@ class PMEReciprocalForce(PrimitiveWithInfer): def infer_shape(self, boxlength_shape, uint_crd_shape, charge_shape): cls_name = self.name N = self.atom_numbers - validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd[0]", cls_name) - validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd[1]", cls_name) - validator.check_int(boxlength_shape[0], 3, Rel.EQ, "boxlength", cls_name) - validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name) + validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name) + validator.check_int(len(boxlength_shape), 1, Rel.EQ, "boxlength_dim", cls_name) + validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name) + + validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name) + validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name) + validator.check_int(boxlength_shape[0], 3, Rel.EQ, "boxlength_shape", cls_name) + validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) return uint_crd_shape def infer_dtype(self, boxlength_type, uint_crd_type, charge_type): @@ -1693,6 +1841,7 @@ class PMEExcludedForce(PrimitiveWithInfer): Args: atom_numbers(int32): the number of atoms, N. + excluded_numbers(int32): the length of excluded list, E. beta(float32): the PME beta parameter, determined by the non-bond cutoff value and simulation precision tolerance. @@ -1716,27 +1865,36 @@ class PMEExcludedForce(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self, atom_numbers, beta): + def __init__(self, atom_numbers, excluded_numbers, beta): validator.check_value_type('atom_numbers', atom_numbers, (int), self.name) + validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name) validator.check_value_type('beta', beta, (float), self.name) self.atom_numbers = atom_numbers + self.excluded_numbers = excluded_numbers self.beta = beta self.init_prim_io_names( inputs=['uint_crd', 'sacler', 'charge', 'excluded_list_start', 'excluded_list', 'excluded_atom_numbers'], outputs=['force']) self.add_prim_attr('atom_numbers', self.atom_numbers) + self.add_prim_attr('excluded_numbers', self.excluded_numbers) self.add_prim_attr('beta', self.beta) def infer_shape(self, uint_crd_shape, sacler_shape, charge_shape, excluded_list_start_shape, excluded_list_shape, excluded_atom_numbers_shape): cls_name = self.name N = self.atom_numbers - validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd[0]", cls_name) - validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd[1]", cls_name) - validator.check_int(sacler_shape[0], 3, Rel.EQ, "sacler", cls_name) - validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name) - validator.check_int(excluded_list_start_shape[0], N, Rel.EQ, "excluded_list_start", cls_name) - validator.check_int(excluded_atom_numbers_shape[0], N, Rel.EQ, "excluded_atom_numbers", cls_name) + validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name) + validator.check_int(len(sacler_shape), 1, Rel.EQ, "sacler_dim", cls_name) + validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(excluded_list_start_shape), 1, Rel.EQ, "excluded_list_start_dim", cls_name) + validator.check_int(len(excluded_atom_numbers_shape), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name) + + validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name) + validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name) + validator.check_int(sacler_shape[0], 3, Rel.EQ, "sacler_shape", cls_name) + validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) + validator.check_int(excluded_list_start_shape[0], N, Rel.EQ, "excluded_list_start_shape", cls_name) + validator.check_int(excluded_atom_numbers_shape[0], N, Rel.EQ, "excluded_atom_numbers_shape", cls_name) return uint_crd_shape def infer_dtype(self, uint_crd_type, sacler_type, charge_type, excluded_list_start_type, excluded_list_type, @@ -1763,6 +1921,7 @@ class PMEEnergy(PrimitiveWithInfer): Args: atom_numbers(int32): the number of atoms, N. + excluded_numbers(int32): the length of excluded list, E. beta(float32): the PME beta parameter, determined by the non-bond cutoff value and simulation precision tolerance. fftx(int32): the number of points for Fourier transform in dimension X. @@ -1795,13 +1954,15 @@ class PMEEnergy(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self, atom_numbers, beta, fftx, ffty, fftz): + def __init__(self, atom_numbers, excluded_numbers, beta, fftx, ffty, fftz): validator.check_value_type('atom_numbers', atom_numbers, (int), self.name) + validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name) validator.check_value_type('beta', beta, (float), self.name) validator.check_value_type('fftx', fftx, (int), self.name) validator.check_value_type('ffty', ffty, (int), self.name) validator.check_value_type('fftz', fftz, (int), self.name) self.atom_numbers = atom_numbers + self.excluded_numbers = excluded_numbers self.beta = beta self.fftx = fftx self.ffty = ffty @@ -1811,6 +1972,7 @@ class PMEEnergy(PrimitiveWithInfer): 'excluded_list', 'excluded_atom_numbers'], outputs=['reciprocal_ene', 'self_ene', 'direct_ene', 'correction_ene']) self.add_prim_attr('atom_numbers', self.atom_numbers) + self.add_prim_attr('excluded_numbers', self.excluded_numbers) self.add_prim_attr('beta', self.beta) self.add_prim_attr('fftx', self.fftx) self.add_prim_attr('ffty', self.ffty) @@ -1820,16 +1982,25 @@ class PMEEnergy(PrimitiveWithInfer): excluded_list, excluded_atom_numbers): cls_name = self.name N = self.atom_numbers - validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name) - validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name) - validator.check_int(box_length[0], 3, Rel.EQ, "box_length", cls_name) - validator.check_int(charge[0], N, Rel.EQ, "charge", cls_name) - validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers[0]", cls_name) - validator.check_int(nl_serial[0], N, Rel.LE, "nl_serial[0]", cls_name) - validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name) - validator.check_int(excluded_list_start[0], N, Rel.EQ, "excluded_list_start", cls_name) - validator.check_int(excluded_atom_numbers[0], N, Rel.EQ, "excluded_atom_numbers", cls_name) - validator.check_int(excluded_list[0], 0, Rel.GE, "excluded_list", cls_name) + validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name) + validator.check_int(len(box_length), 1, Rel.EQ, "sacler_dim", cls_name) + validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name) + validator.check_int(len(nl_serial), 2, Rel.LE, "nl_serial_dim", cls_name) + validator.check_int(len(excluded_list_start), 1, Rel.EQ, "excluded_list_start_dim", cls_name) + validator.check_int(len(excluded_atom_numbers), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name) + validator.check_int(len(excluded_list), 1, Rel.GE, "excluded_list", cls_name) + + validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name) + validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name) + validator.check_int(box_length[0], 3, Rel.EQ, "box_length_shape", cls_name) + validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name) + validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape[0]", cls_name) + validator.check_int(nl_serial[0], N, Rel.LE, "nl_serial_shape[0]", cls_name) + validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name) + validator.check_int(excluded_list_start[0], N, Rel.EQ, "excluded_list_start_shape", cls_name) + validator.check_int(excluded_atom_numbers[0], N, Rel.EQ, "excluded_atom_numbers_shape", cls_name) + validator.check_int(excluded_list[0], 0, Rel.GE, "excluded_list_shape", cls_name) return (1,), (1,), (1,), (1,) def infer_dtype(self, box_length, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start, @@ -1906,16 +2077,25 @@ class LJEnergy(PrimitiveWithInfer): cls_name = self.name N = self.atom_numbers Q = d_LJ_A[0] - validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name) - validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name) - validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name) - validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name) - validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name) - validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name) - validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name) - validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name) - validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name) - validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name) + validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name) + validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name) + validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name) + validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name) + validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name) + validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name) + validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name) + + validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name) + validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name) + validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name) + validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name) + validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name) + validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name) + validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name) + validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name) + validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name) + validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name) return charge def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B): @@ -1983,16 +2163,25 @@ class LJForce(PrimitiveWithInfer): cls_name = self.name N = self.atom_numbers Q = d_LJ_A[0] - validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name) - validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name) - validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name) - validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name) - validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name) - validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name) - validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name) - validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name) - validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name) - validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name) + validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name) + validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name) + validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name) + validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name) + validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name) + validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name) + validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name) + + validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name) + validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name) + validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name) + validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name) + validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name) + validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name) + validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name) + validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name) + validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name) + validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name) return uint_crd def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B): @@ -2058,16 +2247,25 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer): cls_name = self.name N = self.atom_numbers Q = d_LJ_A[0] - validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name) - validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name) - validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name) - validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name) - validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name) - validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name) - validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name) - validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name) - validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name) - validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name) + validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name) + validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name) + validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name) + validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name) + validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name) + validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name) + validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name) + validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name) + + validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name) + validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name) + validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name) + validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name) + validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name) + validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name) + validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name) + validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name) + validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name) + validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name) return uint_crd def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B): diff --git a/model_zoo/research/hpc/sponge/src/particle_mesh_ewald.py b/model_zoo/research/hpc/sponge/src/particle_mesh_ewald.py index 80130940a67..833a680d6e9 100644 --- a/model_zoo/research/hpc/sponge/src/particle_mesh_ewald.py +++ b/model_zoo/research/hpc/sponge/src/particle_mesh_ewald.py @@ -119,18 +119,19 @@ class Particle_Mesh_Ewald(nn.Cell): tempi += 4 def PME_Energy(self, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof, excluded_list_start, - excluded_list, excluded_numbers): + excluded_list, excluded_numbers, excluded_atom_numbers): """PME_Energy""" - self.pmee = P.PMEEnergy(self.atom_numbers, self.beta, self.fftx, self.ffty, self.fftz) + self.pmee = P.PMEEnergy(self.atom_numbers, excluded_atom_numbers, self.beta, self.fftx, self.ffty, self.fftz) self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy = \ self.pmee(self.box_length, uint_crd, charge, nl_atom_numbers, nl_atom_serial, uint_dr_to_dr_cof, excluded_list_start, excluded_list, excluded_numbers) return self.reciprocal_energy, self.self_energy, self.direct_energy, self.correction_energy def PME_Excluded_Force(self, uint_crd, scaler, charge, excluded_list_start, excluded_list, - excluded_numbers): + excluded_numbers, excluded_atom_numbers): """PME Excluded Force""" - self.pmeef = P.PMEExcludedForce(atom_numbers=self.atom_numbers, beta=self.beta) + self.pmeef = P.PMEExcludedForce(atom_numbers=self.atom_numbers, excluded_numbers=excluded_atom_numbers, + beta=self.beta) self.frc = self.pmeef(uint_crd, scaler, charge, excluded_list_start, excluded_list, excluded_numbers) return self.frc diff --git a/model_zoo/research/hpc/sponge/src/simulation_initial.py b/model_zoo/research/hpc/sponge/src/simulation_initial.py index f8b3846f2b7..03771c80359 100644 --- a/model_zoo/research/hpc/sponge/src/simulation_initial.py +++ b/model_zoo/research/hpc/sponge/src/simulation_initial.py @@ -124,7 +124,7 @@ class Simulation(nn.Cell): pme_excluded_frc = pme_method.PME_Excluded_Force( md_info.uint_crd, md_info.uint_dr_to_dr_cof, md_info.charge, nb_info.excluded_list_start, nb_info.excluded_list, - nb_info.excluded_numbers) + nb_info.excluded_numbers, nb_info.excluded_atom_numbers) frc_t += pme_excluded_frc.asnumpy() pme_reciprocal_frc = pme_method.PME_Reciprocal_Force(md_info.uint_crd, md_info.charge) @@ -149,7 +149,7 @@ class Simulation(nn.Cell): _ = self.pme_method.PME_Energy( self.md_info.uint_crd, self.md_info.charge, self.nb_info.nl_atom_numbers, self.nb_info.nl_atom_serial, self.md_info.uint_dr_to_dr_cof, self.nb_info.excluded_list_start, self.nb_info.excluded_list, - self.nb_info.excluded_numbers) + self.nb_info.excluded_numbers, self.nb_info.excluded_atom_numbers) _ = self.pme_method.Energy_Device_To_Host() def Main_After_Calculate_Energy(self):