Merge pull request !23873 from jiahongQian/master_0918
This commit is contained in:
i-robot 2021-09-23 01:04:20 +00:00 committed by Gitee
commit d6d4a4a5ee
3 changed files with 8 additions and 3 deletions

View File

@ -29,6 +29,5 @@ MS_REG_GPU_KERNEL_TWO(RefreshUintCrd,
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeUInt32),
RefreshUintCrdGpuKernel, float, unsigned int)
} // namespace kernel
} // namespace mindspore

View File

@ -2424,7 +2424,7 @@ class LJForce(PrimitiveWithInfer):
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.EQ, "nl_serial_shape[1]", cls_name)
validator.check_int(len(d_lj_a), 1, Rel.EQ, "d_LJ_A_dim", cls_name)
validator.check_int(d_lj_a[0], q, Rel.EQ, "d_LJ_A_shape[0]", cls_name)
validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)

View File

@ -2081,7 +2081,7 @@ class PMEEnergyUpdate(PrimitiveWithInfer):
validator.check_int(charge[0], n, Rel.EQ, "charge_shape", cls_name)
validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape[0]", cls_name)
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], m, Rel.LE, "nl_serial_shape[1]", cls_name)
validator.check_int(nl_serial[1], m, Rel.EQ, "nl_serial_shape[1]", cls_name)
validator.check_int(excluded_list_start[0], n, Rel.EQ, "excluded_list_start_shape", cls_name)
validator.check_int(excluded_atom_numbers[0], n, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
validator.check_int(excluded_list[0], e, Rel.EQ, "excluded_list_shape", cls_name)
@ -2272,6 +2272,8 @@ class ConstrainForceVirial(PrimitiveWithInfer):
validator.check_int(len(atom_j_serials_shape), 1, Rel.EQ, "atom_j_serials_dim", cls_name)
validator.check_int(len(constant_rs_shape), 1, Rel.EQ, "constant_rs_dim", cls_name)
validator.check_int(len(constrain_ks_shape), 1, Rel.EQ, "constrain_ks_dim", cls_name)
validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(mass_inverse[0], n, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
@ -2372,6 +2374,8 @@ class ConstrainForce(PrimitiveWithInfer):
validator.check_int(len(atom_j_serials_shape), 1, Rel.EQ, "atom_j_serials_dim", cls_name)
validator.check_int(len(constant_rs_shape), 1, Rel.EQ, "constant_rs_dim", cls_name)
validator.check_int(len(constrain_ks_shape), 1, Rel.EQ, "constrain_ks_dim", cls_name)
validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(mass_inverse[0], n, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
@ -2480,6 +2484,8 @@ class Constrain(PrimitiveWithInfer):
validator.check_int(len(constant_rs_shape), 1, Rel.EQ, "constant_rs_dim", cls_name)
validator.check_int(len(constrain_ks_shape), 1, Rel.EQ, "constrain_ks_dim", cls_name)
validator.check_int(len(need_pressure), 1, Rel.LE, "need_pressure_dim", cls_name)
validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(mass_inverse[0], n, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)