forked from mindspore-Ecosystem/mindspore
add cell.apply
This commit is contained in:
parent
dcb6f6e5c8
commit
d871738bb3
|
@ -27,6 +27,16 @@
|
|||
参数:
|
||||
- **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
|
||||
|
||||
.. py:method:: apply(fn)
|
||||
|
||||
递归地将 `fn` 应用于每个子Cell(由 `.cells()` 返回)以及自身。通常用于初始化模型的参数。
|
||||
|
||||
参数:
|
||||
- **fn** (function) - 被执行于每个Cell的function。
|
||||
|
||||
返回:
|
||||
Cell类型,Cell本身。
|
||||
|
||||
.. py:method:: auto_cast_inputs(inputs)
|
||||
|
||||
在混合精度下,自动对输入进行类型转换。
|
||||
|
|
|
@ -1414,6 +1414,38 @@ class Cell(Cell_):
|
|||
if "fp32" in flags and flags.get("fp32", False):
|
||||
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
||||
|
||||
def apply(self, fn):
|
||||
"""
|
||||
Applies fn recursively to every subcell (as returned by .cells()) as well as self.
|
||||
Typical use includes initializing the parameters of a model.
|
||||
|
||||
Args:
|
||||
fn (function) – function to be applied to each subcell
|
||||
|
||||
Returns:
|
||||
Cell, self.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore.common.initializer import initializer, One
|
||||
>>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
|
||||
>>> def func(cell):
|
||||
... if isinstance(cell, nn.Dense):
|
||||
... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
|
||||
>>> net.apply(func)
|
||||
SequentialCell<
|
||||
(0): Dense<input_channels=2, output_channels=2, has_bias=True>
|
||||
(1): Dense<input_channels=2, output_channels=2, has_bias=True>
|
||||
>
|
||||
>>> print(net[0].weight.asnumpy())
|
||||
[[1. 1.]
|
||||
[1. 1.]]
|
||||
"""
|
||||
for cell in self.cells():
|
||||
cell.apply(fn)
|
||||
fn(self)
|
||||
return self
|
||||
|
||||
def add_flags(self, **flags):
|
||||
"""
|
||||
Add customized attributes for cell.
|
||||
|
|
|
@ -22,6 +22,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor, Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.common.initializer import initializer, One
|
||||
|
||||
|
||||
class ModA(nn.Cell):
|
||||
|
@ -266,6 +267,25 @@ def test_add_attr():
|
|||
ModAddCellError(ta)
|
||||
|
||||
|
||||
def test_apply():
|
||||
"""
|
||||
Feature: Cell.apply.
|
||||
Description: Verify Cell.apply.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
|
||||
|
||||
def func(cell):
|
||||
if isinstance(cell, nn.Dense):
|
||||
cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
|
||||
|
||||
net.apply(func)
|
||||
|
||||
target = np.ones((2, 2), ms.dtype_to_nptype(net[0].weight.dtype))
|
||||
assert np.allclose(target, net[0].weight.asnumpy())
|
||||
assert np.allclose(target, net[1].weight.asnumpy())
|
||||
|
||||
|
||||
def test_train_eval():
|
||||
m = nn.Cell()
|
||||
assert not m.training
|
||||
|
@ -305,9 +325,6 @@ def test_cell_names():
|
|||
|
||||
|
||||
class TestKwargsNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(TestKwargsNet, self).__init__()
|
||||
|
||||
def construct(self, p1, p2, p3=False, p4=False):
|
||||
if p3:
|
||||
return p1
|
||||
|
@ -315,6 +332,7 @@ class TestKwargsNet(nn.Cell):
|
|||
return P.Add()(p1, p2)
|
||||
return p2
|
||||
|
||||
|
||||
def test_kwargs_default_value1():
|
||||
"""
|
||||
Feature: Supports Cell kwargs inputs.
|
||||
|
@ -334,7 +352,6 @@ def test_kwargs_default_value2():
|
|||
Description: Pass kwargs.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
# Tensor(np.array([1, 2, 3, 4]), ms.float32).reshape((1, 1, 2, 2))
|
||||
x = Tensor([[[[1.0, 2.0], [3.0, 4.0]]]], ms.float32)
|
||||
nn_op = nn.ResizeBilinear()
|
||||
res = nn_op(x, (4, 4), align_corners=True)
|
||||
|
|
|
@ -16,8 +16,10 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common.initializer import initializer, One
|
||||
from ...ut_filter import non_graph_engine
|
||||
|
||||
|
||||
|
@ -276,6 +278,25 @@ def test_add_attr():
|
|||
ModAddCellError(ta)
|
||||
|
||||
|
||||
def test_apply():
|
||||
"""
|
||||
Feature: Cell.apply.
|
||||
Description: Verify Cell.apply.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
|
||||
|
||||
def func(cell):
|
||||
if isinstance(cell, nn.Dense):
|
||||
cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
|
||||
|
||||
net.apply(func)
|
||||
|
||||
target = np.ones((2, 2), ms.dtype_to_nptype(net[0].weight.dtype))
|
||||
assert np.allclose(target, net[0].weight.asnumpy())
|
||||
assert np.allclose(target, net[1].weight.asnumpy())
|
||||
|
||||
|
||||
def test_train_eval():
|
||||
""" test_train_eval """
|
||||
m = nn.Cell()
|
||||
|
|
Loading…
Reference in New Issue