add cell.apply

This commit is contained in:
fandawei 2022-12-26 16:01:05 +08:00
parent dcb6f6e5c8
commit d871738bb3
4 changed files with 84 additions and 4 deletions

View File

@ -27,6 +27,16 @@
参数: 参数:
- **flags** (dict) - Cell的配置信息目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值None。 - **flags** (dict) - Cell的配置信息目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值None。
.. py:method:: apply(fn)
递归地将 `fn` 应用于每个子Cell`.cells()` 返回)以及自身。通常用于初始化模型的参数。
参数:
- **fn** (function) - 被执行于每个Cell的function。
返回:
Cell类型Cell本身。
.. py:method:: auto_cast_inputs(inputs) .. py:method:: auto_cast_inputs(inputs)
在混合精度下,自动对输入进行类型转换。 在混合精度下,自动对输入进行类型转换。

View File

@ -1414,6 +1414,38 @@ class Cell(Cell_):
if "fp32" in flags and flags.get("fp32", False): if "fp32" in flags and flags.get("fp32", False):
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32) self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
def apply(self, fn):
"""
Applies fn recursively to every subcell (as returned by .cells()) as well as self.
Typical use includes initializing the parameters of a model.
Args:
fn (function) function to be applied to each subcell
Returns:
Cell, self.
Examples:
>>> import mindspore.nn as nn
>>> from mindspore.common.initializer import initializer, One
>>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
>>> def func(cell):
... if isinstance(cell, nn.Dense):
... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
>>> net.apply(func)
SequentialCell<
(0): Dense<input_channels=2, output_channels=2, has_bias=True>
(1): Dense<input_channels=2, output_channels=2, has_bias=True>
>
>>> print(net[0].weight.asnumpy())
[[1. 1.]
[1. 1.]]
"""
for cell in self.cells():
cell.apply(fn)
fn(self)
return self
def add_flags(self, **flags): def add_flags(self, **flags):
""" """
Add customized attributes for cell. Add customized attributes for cell.

View File

@ -22,6 +22,7 @@ import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.api import _cell_graph_executor from mindspore.common.api import _cell_graph_executor
from mindspore.common.initializer import initializer, One
class ModA(nn.Cell): class ModA(nn.Cell):
@ -266,6 +267,25 @@ def test_add_attr():
ModAddCellError(ta) ModAddCellError(ta)
def test_apply():
"""
Feature: Cell.apply.
Description: Verify Cell.apply.
Expectation: No exception.
"""
net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
def func(cell):
if isinstance(cell, nn.Dense):
cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
net.apply(func)
target = np.ones((2, 2), ms.dtype_to_nptype(net[0].weight.dtype))
assert np.allclose(target, net[0].weight.asnumpy())
assert np.allclose(target, net[1].weight.asnumpy())
def test_train_eval(): def test_train_eval():
m = nn.Cell() m = nn.Cell()
assert not m.training assert not m.training
@ -305,9 +325,6 @@ def test_cell_names():
class TestKwargsNet(nn.Cell): class TestKwargsNet(nn.Cell):
def __init__(self):
super(TestKwargsNet, self).__init__()
def construct(self, p1, p2, p3=False, p4=False): def construct(self, p1, p2, p3=False, p4=False):
if p3: if p3:
return p1 return p1
@ -315,6 +332,7 @@ class TestKwargsNet(nn.Cell):
return P.Add()(p1, p2) return P.Add()(p1, p2)
return p2 return p2
def test_kwargs_default_value1(): def test_kwargs_default_value1():
""" """
Feature: Supports Cell kwargs inputs. Feature: Supports Cell kwargs inputs.
@ -334,7 +352,6 @@ def test_kwargs_default_value2():
Description: Pass kwargs. Description: Pass kwargs.
Expectation: No exception. Expectation: No exception.
""" """
# Tensor(np.array([1, 2, 3, 4]), ms.float32).reshape((1, 1, 2, 2))
x = Tensor([[[[1.0, 2.0], [3.0, 4.0]]]], ms.float32) x = Tensor([[[[1.0, 2.0], [3.0, 4.0]]]], ms.float32)
nn_op = nn.ResizeBilinear() nn_op = nn.ResizeBilinear()
res = nn_op(x, (4, 4), align_corners=True) res = nn_op(x, (4, 4), align_corners=True)

View File

@ -16,8 +16,10 @@
import numpy as np import numpy as np
import pytest import pytest
import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer, One
from ...ut_filter import non_graph_engine from ...ut_filter import non_graph_engine
@ -276,6 +278,25 @@ def test_add_attr():
ModAddCellError(ta) ModAddCellError(ta)
def test_apply():
"""
Feature: Cell.apply.
Description: Verify Cell.apply.
Expectation: No exception.
"""
net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
def func(cell):
if isinstance(cell, nn.Dense):
cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
net.apply(func)
target = np.ones((2, 2), ms.dtype_to_nptype(net[0].weight.dtype))
assert np.allclose(target, net[0].weight.asnumpy())
assert np.allclose(target, net[1].weight.asnumpy())
def test_train_eval(): def test_train_eval():
""" test_train_eval """ """ test_train_eval """
m = nn.Cell() m = nn.Cell()