init parameter data by default; only keep data-free MetaTensor parameters in auto parallel mode
parent fe7141e9e0
commit d4d6457ea7
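The gist of the change, as a hedged sketch grounded in the docstring and test updates in this diff (not an excerpt of the framework source): a `Parameter` built from an `Initializer` now carries real data by default, so callers no longer need an explicit `init_parameters_data()` before reading `default_input`; only the auto/semi-auto parallel modes keep a data-free `MetaTensor` until `init_data()` runs.

```python
# Hedged sketch of the new default behaviour (names taken from this diff).
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

# Default (non-parallel) modes: the parameter is materialized right away.
weight = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'weight')
assert isinstance(weight.default_input, Tensor)   # no manual init_parameters_data() needed
print(weight.default_input.asnumpy().shape)       # (1, 2, 3)
```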
@@ -414,6 +414,10 @@ class _Executor:
         if auto_parallel_mode:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
+        replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
         if not enable_debug_runtime or enable_ge:
             if auto_parallel_mode:
                 obj.load_parameter_slice(None)
+
+        self._updata_param_node_default_input(phase, replace)
+
         # set parallel inputs in sink mode
@@ -15,25 +15,22 @@
 """Parameter for cell."""
 from copy import copy
 from mindspore import context
 from .._c_expression import ParamValue
 from . import dtype as mstype
 from .initializer import initializer, Initializer
 from .tensor import Tensor, MetaTensor
 from .._checkparam import _check_str_by_regular
 from ..parallel._tensor import _get_slice_index
 from ..parallel._auto_parallel_context import auto_parallel_context

 __all__ = ['Parameter', 'ParameterTuple']

 PARAMETER_NAME_DEFAULT = "Parameter"
 PARAMETER_NAME_PREFIX_MAX_LEN = 1024


-def _check_type(x):
-    """Check input data type"""
-    if not isinstance(x, Parameter):
-        raise ValueError("Should be `Parameter` collection.")
-    return True
+def _is_in_parallel_mode():
+    """Get parallel mode."""
+    return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]


 class Parameter(MetaTensor):
@@ -42,10 +39,10 @@ class Parameter(MetaTensor):

     After initialized `Parameter` is a subtype of `Tensor`.

-    In graph mode, if init `Parameter` by a `Initializer`, the type of Parameter will be a `MetaTensor`
-    not a `Tensor`. `MetaTensor` only save the shape type info of a tensor with no memory usage. The shape
-    can be change while compile for auto-parallel. Call `init_data` will return a Tensor Parameter with
-    initialized data.
+    In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
+    a `Initializer`, the type of Parameter will be a `MetaTensor` not a `Tensor`. `MetaTensor`
+    only save the shape type info of a tensor with no memory usage. The shape can be change while
+    compile for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.

     Note:
         Each parameter of Cell is represented by Parameter class.
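The lazy path the updated docstring describes can be exercised directly; a minimal sketch, adapted from `test_parameter_lazy_init` later in this diff:

```python
# Under "semi_auto_parallel"/"auto_parallel", the parameter stays a
# shape/dtype-only MetaTensor until init_data() materializes it.
import numpy as np
from mindspore import context, Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
assert not isinstance(para.default_input, Tensor)   # still lazy, no data allocated

para.init_data()                                    # materialize the data explicitly
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
context.reset_auto_parallel_context()
```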
@@ -67,7 +64,7 @@ class Parameter(MetaTensor):
         input_class.__init__(obj, *class_init_args)
         # it's better to make the Initializer a kind of metatensor.
         obj.init_mode = None
         if isinstance(default_input, Initializer):
             if not isinstance(obj, Tensor):
                 obj.init_mode = default_input
         return obj
@@ -112,11 +109,10 @@ class Parameter(MetaTensor):
         if isinstance(data, bool):
             raise ValueError('Parameter data can not be `bool`')
         if isinstance(data, Initializer):
-            if context.get_context("mode") == context.PYNATIVE_MODE:
-                # always init data while in pynative mode.
-                data = data.to_tensor()
-            else:
+            if _is_in_parallel_mode():
+                # do not init data while in auto parallel.
                 return (MetaTensor, data.dtype, data.shape)
+            data = data.to_tensor()
         if isinstance(data, Tensor):
             # make a copy of Tensor to init the parameter
             return (Tensor, data.asnumpy(),)
@@ -127,9 +123,9 @@
         return (Tensor, data)

     def __str__(self):
-        value_str = MetaTensor.__repr__(self)
+        value_str = MetaTensor.__str__(self)
         if isinstance(self, Tensor):
-            value_str = Tensor.__repr__(self)
+            value_str = Tensor.__str__(self)
         return f'Parameter (name={self._value.name}, value={value_str})'

     def __repr__(self):
@@ -235,8 +231,6 @@ class Parameter(MetaTensor):
             shape = self.shape
             dtype = self.dtype
             x.default_input = initializer(init, shape=shape, dtype=dtype)
-            if context.get_context("mode") == context.PYNATIVE_MODE:
-                x.init_data()
         return x

     @property
@@ -381,8 +375,12 @@ class ParameterTuple(tuple):
     """
     def __new__(cls, iterable):
         """Create instance object of ParameterTuple."""
-        g = (x for x in iterable if _check_type(x))
-        return tuple.__new__(ParameterTuple, g)
+        data = tuple(iterable)
+        for x in data:
+            if not isinstance(x, Parameter):
+                raise TypeError(f"ParameterTuple input should be `Parameter` collection."
+                                f"But got a {type(iterable)}, {iterable}")
+        return tuple.__new__(ParameterTuple, tuple(data))

     def clone(self, prefix, init='same'):
         """
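With the rewritten `ParameterTuple.__new__`, non-`Parameter` elements now raise `TypeError` (previously `ValueError` via the removed `_check_type`), which is why the unit-test expectations change below. A small hedged usage sketch:

```python
import pytest
from mindspore import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

p1 = Parameter(initializer('ones', [2], mstype.float32), 'p1')
p2 = Parameter(initializer('ones', [2], mstype.float32), 'p2')

params = ParameterTuple([p1, p2])            # any iterable of Parameter is accepted
with pytest.raises(TypeError):               # non-Parameter element is rejected
    ParameterTuple([p1, "not_a_parameter"])
```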
@@ -120,7 +120,6 @@ def test_yolov3():
     net = yolov3_resnet18(ConfigYOLOV3ResNet18())
     net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
     init_net_param(net)

     total_epoch_size = 60
     lr = Tensor(get_lr(learning_rate=lr_init, start_step=0,
@@ -207,7 +207,6 @@ def test_bert_percision():
     netwithgrads.set_train(True)
     model = Model(netwithgrads)
     callback = ModelCallback()
-    netwithloss.init_parameters_data()
     params = netwithloss.trainable_params()
     for param in params:
         value = param.default_input
@@ -279,7 +278,6 @@ def test_bert_performance():
     netwithgrads.set_train(True)
     model = Model(netwithgrads)
     callback = ModelCallback()
-    netwithloss.init_parameters_data()
     params = netwithloss.trainable_params()
     for param in params:
         value = param.default_input
@@ -17,7 +17,7 @@
 import numpy as np
 import pytest

-from mindspore import Tensor, Parameter, ParameterTuple
+from mindspore import context, Tensor, Parameter, ParameterTuple
 from mindspore._checkparam import _check_str_by_regular
 from mindspore.common import dtype as mstype
 from mindspore.common.initializer import initializer
@@ -43,11 +43,11 @@ def test_parameter_tuple_illegal():
     ParameterTuple(ptuple)
     with pytest.raises(TypeError):
         ParameterTuple(p1)
-    with pytest.raises(ValueError):
+    with pytest.raises(TypeError):
         ParameterTuple(plist2)
-    with pytest.raises(ValueError):
+    with pytest.raises(TypeError):
         ParameterTuple(ptuple_str)
-    with pytest.raises(ValueError):
+    with pytest.raises(TypeError):
         ParameterTuple(pstr)
     with pytest.raises(TypeError):
         ParameterTuple(pnum)
@@ -136,6 +136,9 @@ def test_check_str_by_regular():
     _check_str_by_regular(str6)

 def test_parameter_lazy_init():
+    # support lazy init in SEMI_AUTO_PARALLEL mode
+    context.reset_auto_parallel_context()
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
     # Call init_data() without set default_input.
     para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
     assert not isinstance(para.default_input, Tensor)
@@ -167,3 +170,4 @@ def test_parameter_lazy_init():
     # expect no effect.
     para.init_data()
     assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
+    context.reset_auto_parallel_context()
@@ -71,7 +71,7 @@ def test_group_lr():
     assert opt.dynamic_lr is False
     assert opt.is_group_params_ordered is True
     for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()):
-        if param in conv_params:
+        if 'conv' in param.name:
             assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy())
         else:
             assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy())
@@ -103,7 +103,7 @@ def test_group_dynamic_1():
     assert opt.dynamic_lr is True
     assert opt.is_group_params_ordered is True
     for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()):
-        if param in conv_params:
+        if 'conv' in param.name:
             assert np.all(lr.learning_rate.data.asnumpy() == \
                           Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy())
         else:
@@ -135,7 +135,7 @@ def test_group_dynamic_2():
     assert opt.is_group is True
     assert opt.dynamic_lr is True
     for lr, param in zip(opt.learning_rate, opt.parameters):
-        if param in conv_params:
+        if 'conv' in param.name:
             assert np.all(lr.learning_rate.data.asnumpy() == \
                           Tensor(np.array(list(conv_lr)).astype(np.float32)).asnumpy())
         else:
@@ -203,7 +203,7 @@ def test_weight_decay():
     assert opt.is_group_params_ordered is True
     for weight_decay, decay_flags, param, order_param in zip(
             opt.weight_decay, opt.decay_flags, opt.parameters, net.trainable_params()):
-        if param in conv_params:
+        if 'conv' in param.name:
             assert weight_decay == conv_weight_decay
             assert decay_flags is True
         else:
@@ -303,11 +303,11 @@ def test_order_params_1():
     assert opt.is_group_params_ordered is True
     for weight_decay, decay_flags, lr, param, order_param in zip(
             opt.weight_decay, opt.decay_flags, opt.learning_rate, opt.parameters, bias_params+conv_params):
-        if param in conv_params:
+        if 'conv' in param.name:
             assert np.all(lr.data.asnumpy() == Tensor(0.1, mstype.float32).asnumpy())
             assert weight_decay == 0.01
             assert decay_flags is True
-        elif param in bias_params:
+        elif 'bias' in param.name:
             assert np.all(lr.data.asnumpy() == Tensor(0.01, mstype.float32).asnumpy())
             assert weight_decay == 0.0
             assert decay_flags is False
@@ -342,11 +342,11 @@ def test_order_params_2():
     all_lr = opt.get_lr_parameter(fc1_params+conv_params)
     for weight_decay, decay_flags, lr, param, order_param in zip(
             opt.weight_decay, opt.decay_flags, all_lr, opt.parameters, fc1_params+conv_params):
-        if param in conv_params:
+        if 'conv' in param.name:
             assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy())
             assert weight_decay == conv_weight_decay
             assert decay_flags is True
-        elif param in fc1_params:
+        elif 'fc1' in param.name:
             assert np.all(lr.data.asnumpy() == Tensor(fc1_lr, mstype.float32).asnumpy())
             assert weight_decay == default_wd
             assert decay_flags is False
@@ -17,7 +17,7 @@ import numpy as np

 import mindspore as ms
 import mindspore.ops.operations as P
-from mindspore import Tensor
+from mindspore import Tensor, context
 from mindspore.common.api import ms_function
 from mindspore.common.dtype import get_py_obj_dtype
 from mindspore.ops import composite as C
@@ -25,6 +25,9 @@ from mindspore.ops import functional as F
 from mindspore.ops.composite import grad_all_with_sens
 from ...ut_filter import non_graph_engine

+# pylint: disable=unused-argument
+def setup_module(module):
+    context.set_context(mode=context.PYNATIVE_MODE)

 def mul(x, y):
     return x * y