forked from mindspore-Ecosystem/mindspore
add cell.apply
This commit is contained in:
parent
dcb6f6e5c8
commit
d871738bb3
|
@ -27,6 +27,16 @@
|
||||||
参数:
|
参数:
|
||||||
- **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
|
- **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
|
||||||
|
|
||||||
|
.. py:method:: apply(fn)
|
||||||
|
|
||||||
|
递归地将 `fn` 应用于每个子Cell(由 `.cells()` 返回)以及自身。通常用于初始化模型的参数。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **fn** (function) - 被执行于每个Cell的function。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Cell类型,Cell本身。
|
||||||
|
|
||||||
.. py:method:: auto_cast_inputs(inputs)
|
.. py:method:: auto_cast_inputs(inputs)
|
||||||
|
|
||||||
在混合精度下,自动对输入进行类型转换。
|
在混合精度下,自动对输入进行类型转换。
|
||||||
|
|
|
@ -1414,6 +1414,38 @@ class Cell(Cell_):
|
||||||
if "fp32" in flags and flags.get("fp32", False):
|
if "fp32" in flags and flags.get("fp32", False):
|
||||||
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
||||||
|
|
||||||
|
def apply(self, fn):
|
||||||
|
"""
|
||||||
|
Applies fn recursively to every subcell (as returned by .cells()) as well as self.
|
||||||
|
Typical use includes initializing the parameters of a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn (function) – function to be applied to each subcell
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cell, self.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mindspore.nn as nn
|
||||||
|
>>> from mindspore.common.initializer import initializer, One
|
||||||
|
>>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
|
||||||
|
>>> def func(cell):
|
||||||
|
... if isinstance(cell, nn.Dense):
|
||||||
|
... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
|
||||||
|
>>> net.apply(func)
|
||||||
|
SequentialCell<
|
||||||
|
(0): Dense<input_channels=2, output_channels=2, has_bias=True>
|
||||||
|
(1): Dense<input_channels=2, output_channels=2, has_bias=True>
|
||||||
|
>
|
||||||
|
>>> print(net[0].weight.asnumpy())
|
||||||
|
[[1. 1.]
|
||||||
|
[1. 1.]]
|
||||||
|
"""
|
||||||
|
for cell in self.cells():
|
||||||
|
cell.apply(fn)
|
||||||
|
fn(self)
|
||||||
|
return self
|
||||||
|
|
||||||
def add_flags(self, **flags):
|
def add_flags(self, **flags):
|
||||||
"""
|
"""
|
||||||
Add customized attributes for cell.
|
Add customized attributes for cell.
|
||||||
|
|
|
@ -22,6 +22,7 @@ import mindspore.nn as nn
|
||||||
from mindspore import Tensor, Parameter
|
from mindspore import Tensor, Parameter
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.common.api import _cell_graph_executor
|
from mindspore.common.api import _cell_graph_executor
|
||||||
|
from mindspore.common.initializer import initializer, One
|
||||||
|
|
||||||
|
|
||||||
class ModA(nn.Cell):
|
class ModA(nn.Cell):
|
||||||
|
@ -266,6 +267,25 @@ def test_add_attr():
|
||||||
ModAddCellError(ta)
|
ModAddCellError(ta)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply():
|
||||||
|
"""
|
||||||
|
Feature: Cell.apply.
|
||||||
|
Description: Verify Cell.apply.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
|
||||||
|
|
||||||
|
def func(cell):
|
||||||
|
if isinstance(cell, nn.Dense):
|
||||||
|
cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
|
||||||
|
|
||||||
|
net.apply(func)
|
||||||
|
|
||||||
|
target = np.ones((2, 2), ms.dtype_to_nptype(net[0].weight.dtype))
|
||||||
|
assert np.allclose(target, net[0].weight.asnumpy())
|
||||||
|
assert np.allclose(target, net[1].weight.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_train_eval():
|
def test_train_eval():
|
||||||
m = nn.Cell()
|
m = nn.Cell()
|
||||||
assert not m.training
|
assert not m.training
|
||||||
|
@ -305,9 +325,6 @@ def test_cell_names():
|
||||||
|
|
||||||
|
|
||||||
class TestKwargsNet(nn.Cell):
|
class TestKwargsNet(nn.Cell):
|
||||||
def __init__(self):
|
|
||||||
super(TestKwargsNet, self).__init__()
|
|
||||||
|
|
||||||
def construct(self, p1, p2, p3=False, p4=False):
|
def construct(self, p1, p2, p3=False, p4=False):
|
||||||
if p3:
|
if p3:
|
||||||
return p1
|
return p1
|
||||||
|
@ -315,6 +332,7 @@ class TestKwargsNet(nn.Cell):
|
||||||
return P.Add()(p1, p2)
|
return P.Add()(p1, p2)
|
||||||
return p2
|
return p2
|
||||||
|
|
||||||
|
|
||||||
def test_kwargs_default_value1():
|
def test_kwargs_default_value1():
|
||||||
"""
|
"""
|
||||||
Feature: Supports Cell kwargs inputs.
|
Feature: Supports Cell kwargs inputs.
|
||||||
|
@ -334,7 +352,6 @@ def test_kwargs_default_value2():
|
||||||
Description: Pass kwargs.
|
Description: Pass kwargs.
|
||||||
Expectation: No exception.
|
Expectation: No exception.
|
||||||
"""
|
"""
|
||||||
# Tensor(np.array([1, 2, 3, 4]), ms.float32).reshape((1, 1, 2, 2))
|
|
||||||
x = Tensor([[[[1.0, 2.0], [3.0, 4.0]]]], ms.float32)
|
x = Tensor([[[[1.0, 2.0], [3.0, 4.0]]]], ms.float32)
|
||||||
nn_op = nn.ResizeBilinear()
|
nn_op = nn.ResizeBilinear()
|
||||||
res = nn_op(x, (4, 4), align_corners=True)
|
res = nn_op(x, (4, 4), align_corners=True)
|
||||||
|
|
|
@ -16,8 +16,10 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Tensor, Parameter
|
from mindspore import Tensor, Parameter
|
||||||
|
from mindspore.common.initializer import initializer, One
|
||||||
from ...ut_filter import non_graph_engine
|
from ...ut_filter import non_graph_engine
|
||||||
|
|
||||||
|
|
||||||
|
@ -276,6 +278,25 @@ def test_add_attr():
|
||||||
ModAddCellError(ta)
|
ModAddCellError(ta)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply():
|
||||||
|
"""
|
||||||
|
Feature: Cell.apply.
|
||||||
|
Description: Verify Cell.apply.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
|
||||||
|
|
||||||
|
def func(cell):
|
||||||
|
if isinstance(cell, nn.Dense):
|
||||||
|
cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
|
||||||
|
|
||||||
|
net.apply(func)
|
||||||
|
|
||||||
|
target = np.ones((2, 2), ms.dtype_to_nptype(net[0].weight.dtype))
|
||||||
|
assert np.allclose(target, net[0].weight.asnumpy())
|
||||||
|
assert np.allclose(target, net[1].weight.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_train_eval():
|
def test_train_eval():
|
||||||
""" test_train_eval """
|
""" test_train_eval """
|
||||||
m = nn.Cell()
|
m = nn.Cell()
|
||||||
|
|
Loading…
Reference in New Issue