fix constrain

This commit is contained in:
q00596439 2021-09-29 17:16:08 +08:00
parent ac82cc3221
commit 08bfb9e9b9
2 changed files with 22 additions and 11 deletions

View File

@ -3123,8 +3123,8 @@ class NeighborListUpdate(PrimitiveWithInfer):
The data type is int32 and the shape is :math:`(n,)`.
- **need_refresh_flag** (Tensor) - Whether the neighbor list of each atom need update or not.
The data type is int32 and the shape is :math:`(1,)`.
- **refresh_count** (Union[Tensor, Scalar]) - Count how many iteration steps have passed since last update.
The data type is int32 and the shape is :math:`(1,)`, or refresh_count is a scalar with no shape.
- **refresh_count** (Tensor) - Count how many iteration steps have passed since last update.
The data type is int32 and the shape is :math:`(1,)` or :math:`()`.
Outputs:
- **res** (Tensor) - The return value after updating successfully.

View File

@ -873,7 +873,7 @@ class NeighborListRefresh(PrimitiveWithInfer):
- **need_refresh_flag** (Tensor) - Whether the neighbor list of each atom need update or not.
The data type is int32 and the shape is :math:`(1,)`.
- **refresh_count** (Union[Tensor, Scalar]) - Count how many iteration steps have passed since last update.
The data type is int32 and the shape is :math:`(1,)`, or refresh_count is a scalar with no shape.
The data type is int32 and the shape is :math:`(1,)` or :math:`()`.
Outputs:
- **res** (Tensor) - The return value after updating successfully.
@ -2292,7 +2292,7 @@ class ConstrainForceVirial(PrimitiveWithInfer):
validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(mass_inverse[0], n, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(mass_inverse[0], n, Rel.EQ, "mass_inverse_shape", cls_name)
validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
@ -2330,8 +2330,12 @@ class ConstrainForce(PrimitiveWithInfer):
half_exp_gamma_plus_half (float32): half exp_gamma plus half q.
Inputs:
- **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
The data type is uint32 and the shape is :math:`(n, 3)`.
- **crd** (Tensor) - The coordinate of each atom.
The data type is float32 and the shape is :math:`(n, 3)`.
- **quarter_cof** (Tensor) - The 3-D scale factor.
The data type is float32 and the shape is :math:`(3,)`.
- **mass_inverse** (Tensor) - The inverse value of mass of each atom.
The data type is float32 and the shape is :math:`(n,)`.
- **scaler** (Tensor) - The 3-D scale factor (x, y, z),
The data type is float32 and the shape is :math:`(3,)`.
- **pair_dr** (Tensor) - The displacement vector of each constrained atom pair.
@ -2394,7 +2398,7 @@ class ConstrainForce(PrimitiveWithInfer):
validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(mass_inverse[0], n, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(mass_inverse[0], n, Rel.EQ, "mass_inverse_shape", cls_name)
validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
@ -2430,10 +2434,15 @@ class Constrain(PrimitiveWithInfer):
constrain_pair_numbers (int32): the number of constrain pairs m.
iteration_numbers (int32): the number of iteration numbers p.
half_exp_gamma_plus_half (float32): half exp_gamma plus half q.
update_interval (int32): the number of update interval, default 10.
Inputs:
- **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
The data type is uint32 and the shape is :math:`(n, 3)`.
- **crd** (Tensor) - The coordinate of each atom.
The data type is float32 and the shape is :math:`(n, 3)`.
- **quarter_cof** (Tensor) - The 3-D scale factor.
The data type is float32 and the shape is :math:`(3,)`.
- **mass_inverse** (Tensor) - The inverse value of mass of each atom.
The data type is float32 and the shape is :math:`(n,)`.
- **scaler** (Tensor) - The 3-D scale factor (x, y, z),
The data type is float32 and the shape is :math:`(3,)`.
- **pair_dr** (Tensor) - The displacement vector of each constrained atom pair.
@ -2447,7 +2456,7 @@ class Constrain(PrimitiveWithInfer):
- **constrain_ks** (Tensor) - The coefficient of each constrained atom pair.
The data type is float32 and the shape is :math:`(m,)`.
- **need_pressure** (Tensor) - If need pressure, 1 else 0.
The data type is int32 and the shape is :math:`(1,)`.
The data type is int32 and the shape is :math:`(1,)` or :math:`()`.
Outputs:
- **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
@ -2504,7 +2513,7 @@ class Constrain(PrimitiveWithInfer):
validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(mass_inverse[0], n, Rel.EQ, "quarter_cof_shape", cls_name)
validator.check_int(mass_inverse[0], n, Rel.EQ, "mass_inverse_shape", cls_name)
validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
@ -2512,6 +2521,8 @@ class Constrain(PrimitiveWithInfer):
validator.check_int(atom_j_serials_shape[0], m, Rel.EQ, "atom_j_serials_shape[0]", cls_name)
validator.check_int(constant_rs_shape[0], m, Rel.EQ, "constant_rs_shape[0]", cls_name)
validator.check_int(constrain_ks_shape[0], m, Rel.EQ, "constrain_ks_shape[0]", cls_name)
if need_pressure:
validator.check_int(need_pressure[0], 1, Rel.EQ, "need_pressure_shape", self.name)
return [n, 3], [n, 3], [m,]
def infer_dtype(self, crd, quarter_cof, mass_inverse, scaler_dtype, pair_dr_dtype, atom_i_serials_dtype,