add vmap model ensembling doc and ut

This commit is contained in:
Erpim 2022-09-16 10:58:18 +08:00
parent 2a2c5eb86d
commit 1d237ec6e1
3 changed files with 91 additions and 4 deletions

View File

@ -17,10 +17,10 @@ mindspore.ops.vmap
- 当在vmap作用域内调用随机数生成方法时,每次在向量函数之间生成相同的随机数。如果希望每个向量分支使用不同的随机数,需要提前从外部生成一批随机数,然后将其传入vmap。
参数:
- **fn** (Union[Cell, Function]) - 待沿参数轴映射的函数该函数至少拥有一个输入参数并且返回值为一个或多个Tensor或Tensor支持的数据类型。
- **fn** (Union[Cell, Function, CellList]) - 待沿参数轴映射的函数,该函数至少拥有一个输入参数,并且返回值为一个或多个Tensor或Tensor支持的数据类型。当 `fn` 的类型是CellList时,为模型集成场景,需要确保每个单元的结构相同,并且单元数量与映射轴索引对应的size( `axis_size` )一致。
- **in_axes** (Union[int, list, tuple]) - 指定输入参数映射的轴索引。如果 `in_axes` 是一个整数,则 `fn` 的所有输入参数都将根据此轴索引进行映射。
如果 `in_axes` 是一个tuple或list(仅支持由整数或None组成),则其长度应与 `fn` 的输入参数的个数一致,分别表示相应位置参数的映射轴索引。
请注意,每个参数对应的整数轴索引的取值范围必须在 :math:`[-ndim, ndim)` 中,其中 `ndim` 是参数的维度。None表示不沿任何轴映射。并且 `in_axes` 中必须至少有一个位置参数的映射轴索引不为None。 所有参数的映射轴索引对应的size `axis_size` 必须相等。默认值0。
请注意,每个参数对应的整数轴索引的取值范围必须在 :math:`[-ndim, ndim)` 中,其中 `ndim` 是参数的维度。None表示不沿任何轴映射。并且 `in_axes` 中必须至少有一个位置参数的映射轴索引不为None。所有参数的映射轴索引对应的size `axis_size` 必须相等。默认值0。
- **out_axes** (Union[int, list, tuple]) - 指定映射轴呈现在输出中的索引位置。如果 `out_axes` 是一个整数,则 `fn` 的所有输出都根据此axis指定。
如果 `out_axes` 是一个tuple或list(仅支持由整数或None组成),其长度应与 `fn` 的输出个数相等。
请注意,每个输出对应的整数轴索引的取值范围必须在 :math:`[-ndim, ndim)` 中,其中 `ndim` 是 `vmap` 映射后的函数的输出的维度。

View File

@ -218,8 +218,10 @@ def vmap(fn, in_axes=0, out_axes=0):
you need to generate batch random numbers externally in advance and then transfer them to vmap.
Args:
fn (Union[Cell, Function]): Function to be mapped along the parameter axes, which takes at least one argument
and returns one or more Tensors or the type of data supported by the MindSpore Tensor.
fn (Union[Cell, Function, CellList]): Function to be mapped along the parameter axes, which takes at least one
argument and returns one or more Tensors or the type of data supported by the MindSpore Tensor. When it is
a CellList, i.e. the model ensembling scenario, it is necessary to ensure that the structure of each cell is
the same and that the number of cells is consistent with the size of the mapped axes (`axis_size`).
in_axes (Union[int, list, tuple]): Specifies which dimensions (axes) of the inputs should be mapped over.
If `in_axes` is an integer, all arguments of `fn` are mapped over according to this axis index. If `in_axes`
is a tuple or list, it may only be composed of integers or Nones, and its length should be equal to the number of

View File

@ -20,6 +20,7 @@ import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops.functional import vmap
from mindspore.common.parameter import Parameter
# Select graph mode at import time, before any test in this module runs.
context.set_context(mode=context.GRAPH_MODE)
@ -219,3 +220,87 @@ def test_scalar_with_non_zero_axis():
with pytest.raises(RuntimeError) as ex:
vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(0, 1))(x_hat, y_hat, z_hat)
assert "The axis: 1 in 'out_axes' is out of bounds for array of dimension [-1,1)." in str(ex.value)
class AssignNetWithTwoParams(nn.Cell):
    """
    Cell holding two parameters, `ref_a` and `ref_b`.

    `construct` assigns its input to `ref_a` and returns the assign result added to `ref_b`.
    """

    def __init__(self):
        super().__init__()
        self.assign = P.Assign()
        # Both parameters start from the same values but are separate Parameter objects.
        self.ref_a = Parameter(Tensor([0, 1, 2], mstype.float32), name='ref_a')
        self.ref_b = Parameter(Tensor([0, 1, 2], mstype.float32), name='ref_b')

    def construct(self, replace_tensor):
        # Overwrite `ref_a` with the input, then add `ref_b` to the assign result.
        assigned = self.assign(self.ref_a, replace_tensor)
        return self.ref_b + assigned
class AssignNetWithSingleParam(nn.Cell):
    """
    Cell holding a single parameter, `ref_a`.

    `construct` assigns its input to `ref_a` and returns the assign result.
    """

    def __init__(self):
        super().__init__()
        self.assign = P.Assign()
        self.ref_a = Parameter(Tensor([0, 1, 2], mstype.float32), name='ref_a')

    def construct(self, replace_tensor):
        # Overwrite `ref_a` with the input and hand back the assign result.
        return self.assign(self.ref_a, replace_tensor)
class AssignNetWithTwoArgus(nn.Cell):
    """
    Cell holding a single parameter, `ref_a`, whose `construct` takes two arguments.

    `construct` assigns its first input to `ref_a` and returns the assign result
    together with the second input, unchanged.
    """

    def __init__(self):
        super().__init__()
        self.assign = P.Assign()
        self.ref_a = Parameter(Tensor([0, 1, 2], mstype.float32), name='ref_a')

    def construct(self, replace_tensor, x):
        # Overwrite `ref_a` with the first input; `x` is passed through as the second output.
        assigned = self.assign(self.ref_a, replace_tensor)
        return assigned, x
def test_celllist_with_one_model():
    """
    Feature: vmap model ensembling scenario
    Description: Call vmap with a CellList that wraps only one model.
    Expectation: RuntimeError is raised, whose message states that the size of 'CellList'
                 must be greater than 1, but got 1.
    """
    expected_msg = (
        "In the model ensembling parallel training scenario ('VmapOperation' arg0 is a 'CellList'), "
        "the size of 'CellList' must be greater than 1, but got 1."
    )
    single_model_list = nn.CellList([AssignNetWithSingleParam()])
    batched_input = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float32)

    with pytest.raises(RuntimeError) as ex:
        vmap(single_model_list, in_axes=0)(batched_input)

    # Checked after the `with` block so that the message assertion actually runs.
    assert expected_msg in str(ex.value)
def test_celllist_with_inconsistent_inputs():
    """
    Feature: vmap model ensembling scenario
    Description: Call vmap with a CellList of two models whose `construct` methods take
                 a different number of inputs.
    Expectation: RuntimeError is raised, whose message states that the inputs of the
                 CellList elements should be consistent.
    """
    expected_msg = "'VmapOperation' arg0 is a CellList, whose elements's inputs should be consistent."
    # One-argument cell first, then the two-argument cell.
    mismatched_models = nn.CellList([AssignNetWithSingleParam(), AssignNetWithTwoArgus()])
    batched_input = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float32)

    with pytest.raises(RuntimeError) as ex:
        vmap(mismatched_models, in_axes=0)(batched_input)

    # Checked after the `with` block so that the message assertion actually runs.
    assert expected_msg in str(ex.value)
def test_celllist_with_inconsistent_params():
    """
    Feature: vmap model ensembling scenario
    Description: Call vmap with a CellList of two models that hold a different number of
                 parameters (one versus two).
    Expectation: ValueError is raised, whose message states that the parameter size of each
                 cell should be consistent, but get 1 and 2.
    """
    expected_msg = "Parameter size of each cell should be consistent, but get 1 and 2."
    # Single-parameter cell first, then the two-parameter cell.
    mismatched_models = nn.CellList([AssignNetWithSingleParam(), AssignNetWithTwoParams()])
    batched_input = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float32)

    with pytest.raises(ValueError) as ex:
        vmap(mismatched_models, in_axes=0)(batched_input)

    # Checked after the `with` block so that the message assertion actually runs.
    assert expected_msg in str(ex.value)