Reduce data copy for Parameter init
This commit is contained in:
parent
37ae36866c
commit
ab3d34aa2f
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Parameter for cell."""
|
||||
from copy import copy
|
||||
import sys
|
||||
import numbers
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
|
@ -172,7 +173,8 @@ class Parameter(Tensor_):
|
|||
|
||||
def __new__(cls, default_input, *args, **kwargs):
|
||||
init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
|
||||
input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
|
||||
rc = sys.getrefcount(default_input)
|
||||
input_class, *class_init_args = Parameter._get_parameter_new_args(default_input, rc)
|
||||
new_type = Parameter._get_base_class(input_class)
|
||||
obj = input_class.__new__(new_type)
|
||||
input_class.__init__(obj, *class_init_args)
|
||||
|
@ -233,6 +235,16 @@ class Parameter(Tensor_):
|
|||
new_obj._inited_param = self._inited_param
|
||||
return new_obj
|
||||
|
||||
def __str__(self):
|
||||
return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
|
||||
f'requires_grad={self.requires_grad})'
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def __parameter__(self):
|
||||
"""For parse check."""
|
||||
|
||||
@staticmethod
|
||||
def _get_base_class(input_class):
|
||||
input_class_name = Parameter.__name__
|
||||
|
@ -251,12 +263,16 @@ class Parameter(Tensor_):
|
|||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_parameter_new_args(data):
|
||||
def _get_parameter_new_args(data, rc):
|
||||
"""Set `set_data` of current `Parameter`."""
|
||||
if isinstance(data, bool):
|
||||
raise ValueError('Parameter data can not be `bool`')
|
||||
if isinstance(data, Tensor):
|
||||
if not data.has_init:
|
||||
if rc == 4:
|
||||
# when ref count is 4, means the input data is not referenced
|
||||
# in other place, so we can make a Tensor without copy data.
|
||||
return (Tensor, data)
|
||||
# make a copy of Tensor to init the parameter.
|
||||
return (Tensor, data.asnumpy())
|
||||
|
||||
|
@ -272,16 +288,6 @@ class Parameter(Tensor_):
|
|||
return (Tensor, data, mstype.float32)
|
||||
return (Tensor, data)
|
||||
|
||||
def __str__(self):
|
||||
return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
|
||||
f'requires_grad={self.requires_grad})'
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def __parameter__(self):
|
||||
"""For parse check."""
|
||||
|
||||
def set_param_ps(self, init_in_server=False):
|
||||
"""
|
||||
Set whether the trainable parameter is updated by parameter server and whether the
|
||||
|
@ -522,7 +528,6 @@ class Parameter(Tensor_):
|
|||
raise TypeError("The argument `key` must be int type.")
|
||||
self.param_info.key = value
|
||||
|
||||
|
||||
@property
|
||||
def requires_grad(self):
|
||||
"""
|
||||
|
|
|
@ -1687,8 +1687,6 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('col2im')(self, output_size, kernel_size, dilation, padding_value, stride)
|
||||
|
||||
|
||||
|
||||
def reshape(self, *shape):
|
||||
"""
|
||||
Give a new shape to a tensor without changing its data.
|
||||
|
@ -2790,7 +2788,7 @@ class Tensor(Tensor_):
|
|||
shape = self.shape
|
||||
|
||||
try:
|
||||
arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype))
|
||||
data = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype))
|
||||
except ValueError:
|
||||
msg = "Error shape={}".format(shape)
|
||||
logger.critical(msg)
|
||||
|
@ -2827,14 +2825,13 @@ class Tensor(Tensor_):
|
|||
self.init.seed, _ = self.seed
|
||||
|
||||
with seed_context(self.init):
|
||||
self.init(arr)
|
||||
data = np.array(arr)
|
||||
self.init(data)
|
||||
if opt_shard_group:
|
||||
rank = get_rank(opt_shard_group)
|
||||
size = get_group_size(opt_shard_group)
|
||||
data = np.split(data, size)[rank]
|
||||
self.init = None
|
||||
self.assign_value(Tensor(data, dtype=self.dtype))
|
||||
self.assign_value(Tensor_.from_numpy(data))
|
||||
return self
|
||||
|
||||
def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None):
|
||||
|
|
Loading…
Reference in New Issue