forked from mindspore-Ecosystem/mindspore
fix constrain
This commit is contained in:
parent
ac82cc3221
commit
08bfb9e9b9
|
@ -3123,8 +3123,8 @@ class NeighborListUpdate(PrimitiveWithInfer):
|
|||
The data type is int32 and the shape is :math:`(n,)`.
|
||||
- **need_refresh_flag** (Tensor) - Whether the neighbor list of each atom need update or not.
|
||||
The data type is int32 and the shape is :math:`(1,)`.
|
||||
- **refresh_count** (Union[Tensor, Scalar]) - Count how many iteration steps have passed since last update.
|
||||
The data type is int32 and the shape is :math:`(1,)`, or refresh_count is a scalar with no shape.
|
||||
- **refresh_count** (Tensor) - Count how many iteration steps have passed since last update.
|
||||
The data type is int32 and the shape is :math:`(1,)` or :math:`()`.
|
||||
|
||||
Outputs:
|
||||
- **res** (Tensor) - The return value after updating successfully.
|
||||
|
|
|
@ -873,7 +873,7 @@ class NeighborListRefresh(PrimitiveWithInfer):
|
|||
- **need_refresh_flag** (Tensor) - Whether the neighbor list of each atom need update or not.
|
||||
The data type is int32 and the shape is :math:`(1,)`.
|
||||
- **refresh_count** (Union[Tensor, Scalar]) - Count how many iteration steps have passed since last update.
|
||||
The data type is int32 and the shape is :math:`(1,)`, or refresh_count is a scalar with no shape.
|
||||
The data type is int32 and the shape is :math:`(1,)` or :math:`()`.
|
||||
|
||||
Outputs:
|
||||
- **res** (Tensor) - The return value after updating successfully.
|
||||
|
@ -2292,7 +2292,7 @@ class ConstrainForceVirial(PrimitiveWithInfer):
|
|||
validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
|
||||
validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
|
||||
validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
|
||||
validator.check_int(mass_inverse[0], n, Rel.EQ, "quarter_cof_shape", cls_name)
|
||||
validator.check_int(mass_inverse[0], n, Rel.EQ, "mass_inverse_shape", cls_name)
|
||||
validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
||||
validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
|
||||
validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
|
||||
|
@ -2330,8 +2330,12 @@ class ConstrainForce(PrimitiveWithInfer):
|
|||
half_exp_gamma_plus_half (float32): half exp_gamma plus half q.
|
||||
|
||||
Inputs:
|
||||
- **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
|
||||
The data type is uint32 and the shape is :math:`(n, 3)`.
|
||||
- **crd** (Tensor) - The coordinate of each atom.
|
||||
The data type is float32 and the shape is :math:`(n, 3)`.
|
||||
- **quarter_cof** (Tensor) - The 3-D scale factor.
|
||||
The data type is float32 and the shape is :math:`(3,)`.
|
||||
- **mass_inverse** (Tensor) - The inverse value of mass of each atom.
|
||||
The data type is float32 and the shape is :math:`(n,)`.
|
||||
- **scaler** (Tensor) - The 3-D scale factor (x, y, z),
|
||||
The data type is float32 and the shape is :math:`(3,)`.
|
||||
- **pair_dr** (Tensor) - The displacement vector of each constrained atom pair.
|
||||
|
@ -2394,7 +2398,7 @@ class ConstrainForce(PrimitiveWithInfer):
|
|||
validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
|
||||
validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
|
||||
validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
|
||||
validator.check_int(mass_inverse[0], n, Rel.EQ, "quarter_cof_shape", cls_name)
|
||||
validator.check_int(mass_inverse[0], n, Rel.EQ, "mass_inverse_shape", cls_name)
|
||||
validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
||||
validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
|
||||
validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
|
||||
|
@ -2430,10 +2434,15 @@ class Constrain(PrimitiveWithInfer):
|
|||
constrain_pair_numbers (int32): the number of constrain pairs m.
|
||||
iteration_numbers (int32): the number of iteration numbers p.
|
||||
half_exp_gamma_plus_half (float32): half exp_gamma plus half q.
|
||||
update_interval (int32): the number of update interval, default 10.
|
||||
|
||||
Inputs:
|
||||
- **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
|
||||
The data type is uint32 and the shape is :math:`(n, 3)`.
|
||||
- **crd** (Tensor) - The coordinate of each atom.
|
||||
The data type is float32 and the shape is :math:`(n, 3)`.
|
||||
- **quarter_cof** (Tensor) - The 3-D scale factor.
|
||||
The data type is float32 and the shape is :math:`(3,)`.
|
||||
- **mass_inverse** (Tensor) - The inverse value of mass of each atom.
|
||||
The data type is float32 and the shape is :math:`(n,)`.
|
||||
- **scaler** (Tensor) - The 3-D scale factor (x, y, z),
|
||||
The data type is float32 and the shape is :math:`(3,)`.
|
||||
- **pair_dr** (Tensor) - The displacement vector of each constrained atom pair.
|
||||
|
@ -2447,7 +2456,7 @@ class Constrain(PrimitiveWithInfer):
|
|||
- **constrain_ks** (Tensor) - The coefficient of each constrained atom pair.
|
||||
The data type is float32 and the shape is :math:`(m,)`.
|
||||
- **need_pressure** (Tensor) - If need pressure, 1 else 0.
|
||||
The data type is int32 and the shape is :math:`(1,)`.
|
||||
The data type is int32 and the shape is :math:`(1,)` or :math:`()`.
|
||||
|
||||
Outputs:
|
||||
- **uint_crd** (Tensor) - The unsigned int coordinate value of each atom.
|
||||
|
@ -2504,7 +2513,7 @@ class Constrain(PrimitiveWithInfer):
|
|||
validator.check_int(crd[0], n, Rel.EQ, "crd_shape[0]", cls_name)
|
||||
validator.check_int(crd[1], 3, Rel.EQ, "crd_shape[1]", cls_name)
|
||||
validator.check_int(quarter_cof[0], 3, Rel.EQ, "quarter_cof_shape", cls_name)
|
||||
validator.check_int(mass_inverse[0], n, Rel.EQ, "quarter_cof_shape", cls_name)
|
||||
validator.check_int(mass_inverse[0], n, Rel.EQ, "mass_inverse_shape", cls_name)
|
||||
validator.check_int(scaler_shape[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
||||
validator.check_int(pair_dr_shape[0], m, Rel.EQ, "pair_dr_shape[0]", cls_name)
|
||||
validator.check_int(pair_dr_shape[1], 3, Rel.EQ, "pair_dr_shape[1]", cls_name)
|
||||
|
@ -2512,6 +2521,8 @@ class Constrain(PrimitiveWithInfer):
|
|||
validator.check_int(atom_j_serials_shape[0], m, Rel.EQ, "atom_j_serials_shape[0]", cls_name)
|
||||
validator.check_int(constant_rs_shape[0], m, Rel.EQ, "constant_rs_shape[0]", cls_name)
|
||||
validator.check_int(constrain_ks_shape[0], m, Rel.EQ, "constrain_ks_shape[0]", cls_name)
|
||||
if need_pressure:
|
||||
validator.check_int(need_pressure[0], 1, Rel.EQ, "need_pressure_shape", self.name)
|
||||
return [n, 3], [n, 3], [m,]
|
||||
|
||||
def infer_dtype(self, crd, quarter_cof, mass_inverse, scaler_dtype, pair_dr_dtype, atom_i_serials_dtype,
|
||||
|
|
Loading…
Reference in New Issue