add nxy check in SPONGE

This commit is contained in:
q00596439 2021-09-28 17:35:05 +08:00
parent b16a6a4245
commit ac82cc3221
2 changed files with 10 additions and 8 deletions

View File

@ -3084,13 +3084,13 @@ class NeighborListUpdate(PrimitiveWithInfer):
half_cutoff_with_skin(float32): cutoff_with_skin/2.
cutoff_with_skin_square(float32): the square value of cutoff_with_skin.
refresh_interval(int32): the number of iteration steps between two updates of neighbor list.
max_atom_in_grid_numbers(int32): the maximum number of atoms in one grid.
max_atom_in_grid_numbers(int32): the maximum number of atoms in one grid k.
Inputs:
- **atom_numbers_in_grid_bucket** (Tensor) - The number of atoms in each grid bucket.
The data type is int32 and the shape is :math:`(G,)`.
- **bucket** (Tensor) - (Tensor) - The atom indices in each grid bucket.
The data type is int32 and the shape is :math:`(G, m)`.
The data type is int32 and the shape is :math:`(G, k)`.
- **crd** (Tensor) - The coordinates of each atom.
The data type is float32 and the shape is :math:`(n, 3)`.
- **box_length** (Tensor) - The box length of the simulation box.
@ -3123,8 +3123,8 @@ class NeighborListUpdate(PrimitiveWithInfer):
The data type is int32 and the shape is :math:`(n,)`.
- **need_refresh_flag** (Tensor) - Whether the neighbor list of each atom need update or not.
The data type is int32 and the shape is :math:`(1,)`.
- **refresh_count** (Tensor) - Count how many iteration steps have passed since last update.
The data type is int32 and the shape is :math:`(1,)`.
- **refresh_count** (Union[Tensor, Scalar]) - Count how many iteration steps have passed since last update.
The data type is int32 and the shape is :math:`(1,)`, or refresh_count is a scalar with no shape.
Outputs:
- **res** (Tensor) - The return value after updating successfully.
@ -3160,6 +3160,7 @@ class NeighborListUpdate(PrimitiveWithInfer):
validator.check_value_type('cutoff', cutoff, float, self.name)
validator.check_value_type('skin', skin, float, self.name)
validator.check_value_type('max_atom_in_grid_numbers', max_atom_in_grid_numbers, int, self.name)
validator.check_value_type('nxy', nxy, int, self.name)
validator.check_value_type('excluded_atom_numbers', excluded_atom_numbers, int, self.name)
validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
validator.check_value_type('half_skin_square', half_skin_square, float, self.name)

View File

@ -830,7 +830,7 @@ class NeighborListRefresh(PrimitiveWithInfer):
refresh_interval (int32): the number of iteration steps between two updates of neighbor list. Default: 20.
cutoff (float32): the cutoff distance for short-range force calculation. Default: 10.0.
skin (float32): the maximum value of the distance atom allowed to move. Default: 2.0.
max_atom_in_grid_numbers (int32): the maximum number of atoms in one grid m. Default: 64.
max_atom_in_grid_numbers (int32): the maximum number of atoms in one grid k. Default: 64.
max_neighbor_numbers (int32): The maximum number of neighbors m. Default: 800.
forced_update (int32): the flag that decides whether to force an update. Default: 0.
forced_check (int32): the flag that decides whether to force an check. Default: 0.
@ -839,7 +839,7 @@ class NeighborListRefresh(PrimitiveWithInfer):
- **atom_numbers_in_grid_bucket** (Tensor) - The number of atoms in each grid bucket.
The data type is int32 and the shape is :math:`(G,)`.
- **bucket** (Tensor) - (Tensor) - The atom indices in each grid bucket.
The data type is int32 and the shape is :math:`(G, m)`.
The data type is int32 and the shape is :math:`(G, k)`.
- **crd** (Tensor) - The coordinates of each atom.
The data type is float32 and the shape is :math:`(n, 3)`.
- **box_length** (Tensor) - The box length of the simulation box.
@ -872,8 +872,8 @@ class NeighborListRefresh(PrimitiveWithInfer):
The data type is int32 and the shape is :math:`(n,)`.
- **need_refresh_flag** (Tensor) - Whether the neighbor list of each atom need update or not.
The data type is int32 and the shape is :math:`(1,)`.
- **refresh_count** (Tensor) - Count how many iteration steps have passed since last update.
The data type is int32 and the shape is :math:`(1,)`.
- **refresh_count** (Union[Tensor, Scalar]) - Count how many iteration steps have passed since last update.
The data type is int32 and the shape is :math:`(1,)`, or refresh_count is a scalar with no shape.
Outputs:
- **res** (Tensor) - The return value after updating successfully.
@ -913,6 +913,7 @@ class NeighborListRefresh(PrimitiveWithInfer):
validator.check_value_type('cutoff', cutoff, float, self.name)
validator.check_value_type('skin', skin, float, self.name)
validator.check_value_type('max_atom_in_grid_numbers', max_atom_in_grid_numbers, int, self.name)
validator.check_value_type('nxy', nxy, int, self.name)
validator.check_value_type('excluded_atom_numbers', excluded_atom_numbers, int, self.name)
validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
validator.check_value_type('half_skin_square', half_skin_square, float, self.name)